|
123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694 |
- '''
- The MIT License (MIT)
- Copyright (c) 2013 Dave P.
- '''
- import sys
- VER = sys.version_info[0]
- if VER >= 3:
- import socketserver
- from http.server import BaseHTTPRequestHandler
- from io import StringIO, BytesIO
- else:
- import SocketServer
- from BaseHTTPServer import BaseHTTPRequestHandler
- from StringIO import StringIO
-
- import hashlib
- import base64
- import socket
- import struct
- import ssl
- import errno
- import codecs
- from collections import deque
- from select import select
-
# Public names exported by ``from <module> import *``.
__all__ = [
    'WebSocket',
    'SimpleWebSocketServer',
    'SimpleSSLWebSocketServer',
]
-
def _check_unicode(val):
    """Return True when *val* is a text string.

    Text means ``str`` on Python 3 and ``unicode`` on Python 2.
    """
    # The conditional expression only evaluates ``unicode`` on Python 2,
    # where that name exists.
    text_type = str if VER >= 3 else unicode
    return isinstance(val, text_type)
-
class HTTPRequest(BaseHTTPRequestHandler):
    """Parse a raw HTTP request (request line plus headers) without a server.

    On successful parsing the usual BaseHTTPRequestHandler attributes
    (``command``, ``path``, ``request_version``, ``headers``) are populated.
    On failure, ``error_code`` and ``error_message`` record the reason, and
    attributes such as ``headers`` may be missing.
    """

    def __init__(self, request_text):
        """request_text: the raw request bytes (a bytes/bytearray buffer)."""
        if VER >= 3:
            self.rfile = BytesIO(request_text)
        else:
            self.rfile = StringIO(request_text)
        self.raw_requestline = self.rfile.readline()
        self.error_code = self.error_message = None
        self.parse_request()

    def send_error(self, code, message=None, explain=None):
        """Record a parse error instead of answering a client.

        parse_request() reports malformed input through send_error(). The
        base implementation logs the error and writes an HTTP error response,
        which needs a client connection this object does not have. Without
        this override, a malformed request ends in an unrelated
        AttributeError. ``explain`` is accepted for Python 3 signature
        compatibility.
        """
        self.error_code = code
        self.error_message = message
-
# Close status codes explicitly listed as valid for a received Close frame;
# see WebSocket._handlePacket for how the list is applied.
_VALID_STATUS_CODES = [
    1000, 1001, 1002, 1003, 1007, 1008, 1009, 1010, 1011,
    3000, 3999, 4000, 4999,
]

# Server side of the opening handshake; %(acceptstr)s receives the computed
# Sec-WebSocket-Accept value.
HANDSHAKE_STR = ("HTTP/1.1 101 Switching Protocols\r\n"
                 "Upgrade: WebSocket\r\n"
                 "Connection: Upgrade\r\n"
                 "Sec-WebSocket-Accept: %(acceptstr)s\r\n\r\n")

# Fixed GUID appended to the client's Sec-WebSocket-Key before hashing.
GUID_STR = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'

# Frame opcodes (STREAM is the continuation-frame opcode).
STREAM, TEXT, BINARY = 0x0, 0x1, 0x2
CLOSE, PING, PONG = 0x8, 0x9, 0xA

# States of the byte-at-a-time frame parser (WebSocket._parseMessage).
HEADERB1, HEADERB2 = 1, 3
LENGTHSHORT, LENGTHLONG = 4, 5
MASK, PAYLOAD = 6, 7

# Default limits, in bytes, for the handshake header and one frame payload.
MAXHEADER = 64 * 1024           # 65536
MAXPAYLOAD = 32 * 1024 * 1024   # 33554432
-
class WebSocket(object):
    """A single server-side WebSocket connection (RFC 6455).

    Subclass this and override handleMessage(), handleConnected() and
    handleClose(). The owning server calls _handleData() whenever the socket
    is readable. Outgoing frames are appended to ``sendq`` as
    ``(opcode, payload)`` tuples and written out by the server loop.
    """

    def __init__(self, server, sock, address):
        self.server = server
        self.client = sock
        self.address = address

        # HTTP upgrade handshake state
        self.handshaked = False
        self.headerbuffer = bytearray()
        self.headertoread = 2048

        # state of the frame currently being parsed
        self.fin = 0
        self.data = bytearray()
        self.opcode = 0
        self.hasmask = 0
        self.maskarray = None
        self.length = 0
        self.lengtharray = None
        self.index = 0
        self.request = None
        self.usingssl = False

        # reassembly state for fragmented messages
        self.frag_start = False
        self.frag_type = BINARY
        self.frag_buffer = None
        self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict')
        self.closed = False
        self.sendq = deque()

        self.state = HEADERB1

        # restrict the size of header and payload for security reasons
        self.maxheader = MAXHEADER
        self.maxpayload = MAXPAYLOAD

    def handleMessage(self):
        """
        Called when a complete message has been received: either a single
        unfragmented frame or the final frame of a fragmented message.
        The message is available in self.data.

        If the message is Text then self.data is a unicode object.
        If the message is Binary then self.data is a bytearray object.
        """
        pass

    def handleConnected(self):
        """
        Called when a websocket client has completed the opening handshake.
        """
        pass

    def handleClose(self):
        """
        Called by the server after it has closed this connection's socket.
        This happens after our Close frame is sent, on a protocol or socket
        error, or when the server shuts down.
        """
        pass

    def _handlePacket(self):
        """Act on the complete frame held in self.opcode / self.fin / self.data.

        Raises Exception on any protocol violation; the server responds by
        dropping the connection.
        """
        if self.opcode not in (CLOSE, STREAM, TEXT, BINARY, PING, PONG):
            # unknown or reserved opcode so just close
            raise Exception('unknown opcode')
        if self.opcode in (PING, PONG) and len(self.data) > 125:
            raise Exception('control frame length can not be > 125')

        if self.opcode == CLOSE:
            status = 1000
            reason = u''
            length = len(self.data)

            if length >= 2:
                status = struct.unpack_from('!H', self.data[:2])[0]
                reason = self.data[2:]

                # 3000-3999 (registered) and 4000-4999 (private use) are
                # valid as a whole range (RFC 6455, section 7.4.2).
                if status not in _VALID_STATUS_CODES and not 3000 <= status <= 4999:
                    status = 1002

                if len(reason) > 0:
                    try:
                        reason = reason.decode('utf8', errors='strict')
                    except UnicodeDecodeError:
                        status = 1002
                        # never echo invalid UTF-8 back in our own Close frame
                        reason = u''
            elif length == 1:
                # a one-byte body cannot carry a status code
                status = 1002

            self.close(status, reason)
            return

        if self.fin == 0:
            # non-final frame of a fragmented message
            if self.opcode != STREAM:
                # first fragment: remember the message type
                if self.opcode == PING or self.opcode == PONG:
                    raise Exception('control messages can not be fragmented')
                if self.frag_start is True:
                    # a new message may not start inside another one
                    raise Exception('fragmentation protocol error')

                self.frag_type = self.opcode
                self.frag_start = True
                self.frag_decoder.reset()

                if self.frag_type == TEXT:
                    self.frag_buffer = []
                    utf_str = self.frag_decoder.decode(self.data, final = False)
                    if utf_str:
                        self.frag_buffer.append(utf_str)
                else:
                    self.frag_buffer = bytearray()
                    self.frag_buffer.extend(self.data)

            else:
                # middle fragment
                if self.frag_start is False:
                    raise Exception('fragmentation protocol error')

                if self.frag_type == TEXT:
                    utf_str = self.frag_decoder.decode(self.data, final = False)
                    if utf_str:
                        self.frag_buffer.append(utf_str)
                else:
                    self.frag_buffer.extend(self.data)

        elif self.opcode == STREAM:
            # final fragment: assemble and deliver the whole message
            if self.frag_start is False:
                raise Exception('fragmentation protocol error')

            if self.frag_type == TEXT:
                utf_str = self.frag_decoder.decode(self.data, final = True)
                self.frag_buffer.append(utf_str)
                self.data = u''.join(self.frag_buffer)
            else:
                self.frag_buffer.extend(self.data)
                self.data = self.frag_buffer

            self.handleMessage()

            self.frag_decoder.reset()
            self.frag_type = BINARY
            self.frag_start = False
            self.frag_buffer = None

        elif self.opcode == PING:
            self._sendMessage(False, PONG, self.data)

        elif self.opcode == PONG:
            pass

        else:
            # complete, unfragmented TEXT or BINARY message
            if self.frag_start is True:
                raise Exception('fragmentation protocol error')

            if self.opcode == TEXT:
                try:
                    self.data = self.data.decode('utf8', errors='strict')
                except UnicodeDecodeError:
                    raise Exception('invalid utf-8 payload')

            self.handleMessage()

    def _handleData(self):
        """Read from the socket and feed the bytes to the handshake or frame parser.

        Raises Exception when the peer has closed the socket or on any
        handshake/protocol error; the server then drops the connection.
        """
        # do the HTTP header and handshake
        if self.handshaked is False:
            data = self.client.recv(self.headertoread)
            if not data:
                raise Exception('remote socket closed')

            # accumulate
            self.headerbuffer.extend(data)

            if len(self.headerbuffer) >= self.maxheader:
                raise Exception('header exceeded allowable size')

            # indicates end of HTTP header
            if b'\r\n\r\n' in self.headerbuffer:
                self.request = HTTPRequest(self.headerbuffer)

                # handshake rfc 6455: accept = base64(sha1(key + GUID))
                try:
                    key = self.request.headers['Sec-WebSocket-Key']
                    k = key.encode('ascii') + GUID_STR.encode('ascii')
                    k_s = base64.b64encode(hashlib.sha1(k).digest()).decode('ascii')
                    hStr = HANDSHAKE_STR % {'acceptstr': k_s}
                    self.sendq.append((BINARY, hStr.encode('ascii')))
                    self.handshaked = True
                    self.handleConnected()
                except Exception as e:
                    raise Exception('handshake failed: %s' % str(e))

        # else do normal data
        else:
            data = self.client.recv(8192)
            if not data:
                raise Exception("remote socket closed")

            if VER >= 3:
                for d in data:
                    self._parseMessage(d)
            else:
                for d in data:
                    self._parseMessage(ord(d))

    def close(self, status = 1000, reason = u''):
        """
        Send Close frame to the client. The underlying socket is only closed
        when the client acknowledges the Close frame.

        status is the closing identifier.
        reason is the reason for the close.
        """
        try:
            if self.closed is False:
                close_msg = bytearray()
                close_msg.extend(struct.pack("!H", status))
                if _check_unicode(reason):
                    close_msg.extend(reason.encode('utf-8'))
                else:
                    close_msg.extend(reason)

                self._sendMessage(False, CLOSE, close_msg)

        finally:
            self.closed = True

    def _sendBuffer(self, buff):
        """Write as much of buff to the socket as it accepts.

        Returns None once everything has been sent, or the unsent tail when
        the socket would block (EAGAIN/EWOULDBLOCK). Other socket errors, and
        a zero-byte send, are raised.
        """
        size = len(buff)
        already_sent = 0

        while already_sent < size:
            try:
                # i should be able to send a bytearray
                sent = self.client.send(buff[already_sent:])
                if sent == 0:
                    raise RuntimeError('socket connection broken')
                already_sent += sent

            except socket.error as e:
                # if we have full buffers then wait for them to drain and try again
                if e.errno in (errno.EAGAIN, errno.EWOULDBLOCK):
                    return buff[already_sent:]
                raise

        return None

    def sendFragmentStart(self, data):
        """
        Send the start of a data fragment stream to a websocket client.
        Subsequent data should be sent using sendFragment().
        A fragment stream is completed when sendFragmentEnd() is called.

        If data is a unicode object then the frame is sent as Text.
        If the data is a bytearray object then the frame is sent as Binary.
        """
        opcode = BINARY
        if _check_unicode(data):
            opcode = TEXT
        self._sendMessage(True, opcode, data)

    def sendFragment(self, data):
        """
        see sendFragmentStart()

        Sends a non-final continuation frame. A unicode object is UTF-8
        encoded; the message type was fixed by sendFragmentStart().
        """
        self._sendMessage(True, STREAM, data)

    def sendFragmentEnd(self, data):
        """
        see sendFragmentStart()

        Sends the final continuation frame, completing the message. A unicode
        object is UTF-8 encoded; the message type was fixed by
        sendFragmentStart().
        """
        self._sendMessage(False, STREAM, data)

    def sendMessage(self, data):
        """
        Send websocket data frame to the client.

        If data is a unicode object then the frame is sent as Text.
        If the data is a bytearray object then the frame is sent as Binary.
        """
        opcode = BINARY
        if _check_unicode(data):
            opcode = TEXT
        self._sendMessage(False, opcode, data)

    def _sendMessage(self, fin, opcode, data):
        """Frame data and append it to sendq.

        Despite its name, ``fin`` means "more fragments follow". Pass False
        for a complete message or the last fragment (FIN bit set), and True
        for a non-final fragment (FIN bit clear). Unicode data is UTF-8
        encoded. Frames sent by the server are never masked.
        """
        payload = bytearray()

        b1 = 0
        b2 = 0
        if fin is False:
            b1 |= 0x80
        b1 |= opcode

        if _check_unicode(data):
            data = data.encode('utf-8')

        length = len(data)
        payload.append(b1)

        if length <= 125:
            b2 |= length
            payload.append(b2)

        elif 126 <= length <= 65535:
            b2 |= 126
            payload.append(b2)
            payload.extend(struct.pack("!H", length))

        else:
            b2 |= 127
            payload.append(b2)
            payload.extend(struct.pack("!Q", length))

        if length > 0:
            payload.extend(data)

        self.sendq.append((opcode, payload))

    def _onLengthKnown(self):
        """Advance the parser once self.length holds the frame's payload length."""
        # Reject oversized frames before buffering them. The per-byte check in
        # the PAYLOAD state would reject exactly the same frames, but only
        # after up to maxpayload bytes had been accumulated.
        if self.length > 0 and self.length >= self.maxpayload:
            raise Exception('payload exceeded allowable size')

        if self.hasmask:
            self.maskarray = bytearray()
            self.state = MASK
        else:
            self._beginPayload()

    def _beginPayload(self):
        """Dispatch an empty frame immediately, or switch to reading its payload."""
        if self.length <= 0:
            try:
                self._handlePacket()
            finally:
                self.state = HEADERB1
                self.data = bytearray()
        else:
            self.data = bytearray()
            self.state = PAYLOAD

    def _parseMessage(self, byte):
        """Advance the frame-parsing state machine by one received byte (int 0-255)."""
        # read in the header
        if self.state == HEADERB1:
            self.fin = byte & 0x80
            self.opcode = byte & 0x0F
            self.state = HEADERB2

            self.index = 0
            self.length = 0
            self.lengtharray = bytearray()
            self.data = bytearray()

            # no extensions are negotiated, so the RSV bits must be clear
            rsv = byte & 0x70
            if rsv != 0:
                raise Exception('RSV bit must be 0')

        elif self.state == HEADERB2:
            mask = byte & 0x80
            length = byte & 0x7F

            # Every control frame (opcode 0x8-0xF) carries at most 125 bytes,
            # so the 16/64-bit extended length forms are illegal for them.
            if self.opcode & 0x8 and length > 125:
                raise Exception('control frame length can not be > 125')

            self.hasmask = (mask == 128)

            if length <= 125:
                self.length = length
                self._onLengthKnown()
            elif length == 126:
                # a 16-bit extended payload length follows
                self.lengtharray = bytearray()
                self.state = LENGTHSHORT
            else:
                # length == 127: a 64-bit extended payload length follows
                self.lengtharray = bytearray()
                self.state = LENGTHLONG

        elif self.state == LENGTHSHORT:
            self.lengtharray.append(byte)

            if len(self.lengtharray) > 2:
                raise Exception('short length exceeded allowable size')

            if len(self.lengtharray) == 2:
                self.length = struct.unpack_from('!H', self.lengtharray)[0]
                self._onLengthKnown()

        elif self.state == LENGTHLONG:
            self.lengtharray.append(byte)

            if len(self.lengtharray) > 8:
                raise Exception('long length exceeded allowable size')

            if len(self.lengtharray) == 8:
                self.length = struct.unpack_from('!Q', self.lengtharray)[0]
                self._onLengthKnown()

        # MASK STATE
        elif self.state == MASK:
            self.maskarray.append(byte)

            if len(self.maskarray) > 4:
                raise Exception('mask exceeded allowable size')

            if len(self.maskarray) == 4:
                self._beginPayload()

        # PAYLOAD STATE
        elif self.state == PAYLOAD:
            if self.hasmask:
                self.data.append(byte ^ self.maskarray[self.index % 4])
            else:
                self.data.append(byte)

            # if length exceeds allowable size then we except and remove the connection
            if len(self.data) >= self.maxpayload:
                raise Exception('payload exceeded allowable size')

            # check if we have processed length bytes; if so we are done
            if (self.index + 1) == self.length:
                try:
                    self._handlePacket()
                finally:
                    self.state = HEADERB1
                    self.data = bytearray()
            else:
                self.index += 1
-
-
class SimpleWebSocketServer(object):
    """Single-threaded, select()-based WebSocket server.

    Each accepted connection is wrapped in an instance of ``websocketclass``
    (a WebSocket subclass), and serveforever() multiplexes every socket on
    the calling thread.
    """

    def __init__(self, host, port, websocketclass, selectInterval = 0.1):
        self.websocketclass = websocketclass
        self.serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        self.serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        self.serversocket.bind((host, port))
        self.serversocket.listen(5)
        # select() timeout in seconds; a falsy value blocks indefinitely
        self.selectInterval = selectInterval
        # fileno -> WebSocket for every live client connection
        self.connections = {}
        # objects handed to select(): the listening socket plus client filenos
        self.listeners = [self.serversocket]

    def _decorateSocket(self, sock):
        """Hook for subclasses to wrap a freshly accepted socket (e.g. in TLS)."""
        return sock

    def _constructWebSocket(self, sock, address):
        """Hook for subclasses to customise creation of the per-client object."""
        return self.websocketclass(self, sock, address)

    def close(self):
        """Stop listening and notify every client connection.

        Queues a Close frame on each connection and calls its handleClose()
        hook. Client sockets are not closed here.
        """
        self.serversocket.close()

        for conn in list(self.connections.values()):
            conn.close()
            conn.handleClose()

    def _removeClient(self, fileno):
        """Close the client socket registered under fileno, call handleClose(), and forget it."""
        client = self.connections[fileno]
        try:
            client.client.close()
            client.handleClose()
        finally:
            # unregister even if handleClose() raises, so a closed socket
            # never stays in the select() lists
            del self.connections[fileno]
            self.listeners.remove(fileno)

    def _acceptClient(self):
        """Accept one pending connection and register it; close it on any failure."""
        # Both names start as None so that a failing accept() closes nothing.
        # Otherwise a stale socket from an earlier accept could be closed here.
        sock = newsock = None
        try:
            sock, address = self.serversocket.accept()
            newsock = self._decorateSocket(sock)
            newsock.setblocking(0)
            fileno = newsock.fileno()
            self.connections[fileno] = self._constructWebSocket(newsock, address)
            self.listeners.append(fileno)
        except Exception:
            # Best effort: drop the half-set-up connection and keep serving.
            # Close the decorated socket when it exists, because a TLS wrapper
            # takes over the raw socket's descriptor.
            failed = newsock if newsock is not None else sock
            if failed is not None:
                failed.close()

    def serveforever(self):
        """Run the accept/read/write loop until the listening socket fails.

        Raises Exception('server socket failed') after calling close() when
        select() reports an exceptional condition on the listening socket.
        """
        while True:
            # only poll for writability on clients that have queued output
            writers = [fileno for fileno in self.listeners
                       if fileno != self.serversocket
                       and self.connections[fileno].sendq]

            if self.selectInterval:
                rList, wList, xList = select(self.listeners, writers, self.listeners, self.selectInterval)
            else:
                rList, wList, xList = select(self.listeners, writers, self.listeners)

            for ready in wList:
                client = self.connections[ready]
                try:
                    while client.sendq:
                        opcode, payload = client.sendq.popleft()
                        remaining = client._sendBuffer(payload)
                        if remaining is not None:
                            # socket buffer is full: requeue the tail, retry later
                            client.sendq.appendleft((opcode, remaining))
                            break
                        if opcode == CLOSE:
                            # our Close frame is fully sent; tear the connection down
                            raise Exception('received client close')
                except Exception:
                    self._removeClient(ready)

            for ready in rList:
                if ready == self.serversocket:
                    self._acceptClient()
                    continue
                if ready not in self.connections:
                    # already removed while flushing writes above
                    continue
                client = self.connections[ready]
                try:
                    client._handleData()
                except Exception:
                    self._removeClient(ready)

            for failed in xList:
                if failed == self.serversocket:
                    self.close()
                    raise Exception('server socket failed')
                if failed in self.connections:
                    self._removeClient(failed)
-
-
class SimpleSSLWebSocketServer(SimpleWebSocketServer):
    """SimpleWebSocketServer that wraps every accepted socket in TLS."""

    def __init__(self, host, port, websocketclass, certfile,
                 keyfile, version = None, selectInterval = 0.1):
        """
        certfile and keyfile are passed to SSLContext.load_cert_chain().

        version is the ssl.PROTOCOL_* constant used to build the SSLContext.
        When omitted, the most flexible server-side protocol the ssl module
        offers is used, so each client negotiates the highest TLS version both
        sides support, instead of the connection being pinned to TLS 1.0.
        """
        if version is None:
            version = self._defaultProtocol()

        # Build the TLS context before binding, so that a bad certificate or
        # key path does not leave a bound listening socket behind.
        self.context = ssl.SSLContext(version)
        self.context.load_cert_chain(certfile, keyfile)

        SimpleWebSocketServer.__init__(self, host, port,
                                       websocketclass, selectInterval)

    @staticmethod
    def _defaultProtocol():
        """Return the best generic server-side protocol constant available in ssl."""
        # Looked up by name because availability differs across Python versions.
        for name in ('PROTOCOL_TLS_SERVER', 'PROTOCOL_TLS', 'PROTOCOL_SSLv23'):
            protocol = getattr(ssl, name, None)
            if protocol is not None:
                return protocol
        return ssl.PROTOCOL_TLSv1

    def close(self):
        super(SimpleSSLWebSocketServer, self).close()

    def _decorateSocket(self, sock):
        """Wrap an accepted socket in TLS; the handshake runs during wrapping."""
        sslsock = self.context.wrap_socket(sock, server_side=True)
        return sslsock

    def _constructWebSocket(self, sock, address):
        ws = self.websocketclass(self, sock, address)
        ws.usingssl = True
        return ws

    def serveforever(self):
        super(SimpleSSLWebSocketServer, self).serveforever()
|