# Anders and Briegel in Python

'''
The MIT License (MIT)
Copyright (c) 2013 Dave P.
'''
  5. import sys
  6. VER = sys.version_info[0]
  7. if VER >= 3:
  8. import socketserver
  9. from http.server import BaseHTTPRequestHandler
  10. from io import StringIO, BytesIO
  11. else:
  12. import SocketServer
  13. from BaseHTTPServer import BaseHTTPRequestHandler
  14. from StringIO import StringIO
  15. import hashlib
  16. import base64
  17. import socket
  18. import struct
  19. import ssl
  20. import errno
  21. import codecs
  22. from collections import deque
  23. from select import select
  24. __all__ = ['WebSocket',
  25. 'SimpleWebSocketServer',
  26. 'SimpleSSLWebSocketServer']
  27. def _check_unicode(val):
  28. if VER >= 3:
  29. return isinstance(val, str)
  30. else:
  31. return isinstance(val, unicode)
  32. class HTTPRequest(BaseHTTPRequestHandler):
  33. def __init__(self, request_text):
  34. if VER >= 3:
  35. self.rfile = BytesIO(request_text)
  36. else:
  37. self.rfile = StringIO(request_text)
  38. self.raw_requestline = self.rfile.readline()
  39. self.error_code = self.error_message = None
  40. self.parse_request()
  41. _VALID_STATUS_CODES = [1000, 1001, 1002, 1003, 1007, 1008,
  42. 1009, 1010, 1011, 3000, 3999, 4000, 4999]
  43. HANDSHAKE_STR = (
  44. "HTTP/1.1 101 Switching Protocols\r\n"
  45. "Upgrade: WebSocket\r\n"
  46. "Connection: Upgrade\r\n"
  47. "Sec-WebSocket-Accept: %(acceptstr)s\r\n\r\n"
  48. )
  49. GUID_STR = '258EAFA5-E914-47DA-95CA-C5AB0DC85B11'
  50. STREAM = 0x0
  51. TEXT = 0x1
  52. BINARY = 0x2
  53. CLOSE = 0x8
  54. PING = 0x9
  55. PONG = 0xA
  56. HEADERB1 = 1
  57. HEADERB2 = 3
  58. LENGTHSHORT = 4
  59. LENGTHLONG = 5
  60. MASK = 6
  61. PAYLOAD = 7
  62. MAXHEADER = 65536
  63. MAXPAYLOAD = 33554432
  64. class WebSocket(object):
  65. def __init__(self, server, sock, address):
  66. self.server = server
  67. self.client = sock
  68. self.address = address
  69. self.handshaked = False
  70. self.headerbuffer = bytearray()
  71. self.headertoread = 2048
  72. self.fin = 0
  73. self.data = bytearray()
  74. self.opcode = 0
  75. self.hasmask = 0
  76. self.maskarray = None
  77. self.length = 0
  78. self.lengtharray = None
  79. self.index = 0
  80. self.request = None
  81. self.usingssl = False
  82. self.frag_start = False
  83. self.frag_type = BINARY
  84. self.frag_buffer = None
  85. self.frag_decoder = codecs.getincrementaldecoder('utf-8')(errors='strict')
  86. self.closed = False
  87. self.sendq = deque()
  88. self.state = HEADERB1
  89. # restrict the size of header and payload for security reasons
  90. self.maxheader = MAXHEADER
  91. self.maxpayload = MAXPAYLOAD
  92. def handleMessage(self):
  93. """
  94. Called when websocket frame is received.
  95. To access the frame data call self.data.
  96. If the frame is Text then self.data is a unicode object.
  97. If the frame is Binary then self.data is a bytearray object.
  98. """
  99. pass
  100. def handleConnected(self):
  101. """
  102. Called when a websocket client connects to the server.
  103. """
  104. pass
  105. def handleClose(self):
  106. """
  107. Called when a websocket server gets a Close frame from a client.
  108. """
  109. pass
  110. def _handlePacket(self):
  111. if self.opcode == CLOSE:
  112. pass
  113. elif self.opcode == STREAM:
  114. pass
  115. elif self.opcode == TEXT:
  116. pass
  117. elif self.opcode == BINARY:
  118. pass
  119. elif self.opcode == PONG or self.opcode == PING:
  120. if len(self.data) > 125:
  121. raise Exception('control frame length can not be > 125')
  122. else:
  123. # unknown or reserved opcode so just close
  124. raise Exception('unknown opcode')
  125. if self.opcode == CLOSE:
  126. status = 1000
  127. reason = u''
  128. length = len(self.data)
  129. if length == 0:
  130. pass
  131. elif length >= 2:
  132. status = struct.unpack_from('!H', self.data[:2])[0]
  133. reason = self.data[2:]
  134. if status not in _VALID_STATUS_CODES:
  135. status = 1002
  136. if len(reason) > 0:
  137. try:
  138. reason = reason.decode('utf8', errors='strict')
  139. except:
  140. status = 1002
  141. else:
  142. status = 1002
  143. self.close(status, reason)
  144. return
  145. elif self.fin == 0:
  146. if self.opcode != STREAM:
  147. if self.opcode == PING or self.opcode == PONG:
  148. raise Exception('control messages can not be fragmented')
  149. self.frag_type = self.opcode
  150. self.frag_start = True
  151. self.frag_decoder.reset()
  152. if self.frag_type == TEXT:
  153. self.frag_buffer = []
  154. utf_str = self.frag_decoder.decode(self.data, final = False)
  155. if utf_str:
  156. self.frag_buffer.append(utf_str)
  157. else:
  158. self.frag_buffer = bytearray()
  159. self.frag_buffer.extend(self.data)
  160. else:
  161. if self.frag_start is False:
  162. raise Exception('fragmentation protocol error')
  163. if self.frag_type == TEXT:
  164. utf_str = self.frag_decoder.decode(self.data, final = False)
  165. if utf_str:
  166. self.frag_buffer.append(utf_str)
  167. else:
  168. self.frag_buffer.extend(self.data)
  169. else:
  170. if self.opcode == STREAM:
  171. if self.frag_start is False:
  172. raise Exception('fragmentation protocol error')
  173. if self.frag_type == TEXT:
  174. utf_str = self.frag_decoder.decode(self.data, final = True)
  175. self.frag_buffer.append(utf_str)
  176. self.data = u''.join(self.frag_buffer)
  177. else:
  178. self.frag_buffer.extend(self.data)
  179. self.data = self.frag_buffer
  180. self.handleMessage()
  181. self.frag_decoder.reset()
  182. self.frag_type = BINARY
  183. self.frag_start = False
  184. self.frag_buffer = None
  185. elif self.opcode == PING:
  186. self._sendMessage(False, PONG, self.data)
  187. elif self.opcode == PONG:
  188. pass
  189. else:
  190. if self.frag_start is True:
  191. raise Exception('fragmentation protocol error')
  192. if self.opcode == TEXT:
  193. try:
  194. self.data = self.data.decode('utf8', errors='strict')
  195. except Exception as exp:
  196. raise Exception('invalid utf-8 payload')
  197. self.handleMessage()
  198. def _handleData(self):
  199. # do the HTTP header and handshake
  200. if self.handshaked is False:
  201. data = self.client.recv(self.headertoread)
  202. if not data:
  203. raise Exception('remote socket closed')
  204. else:
  205. # accumulate
  206. self.headerbuffer.extend(data)
  207. if len(self.headerbuffer) >= self.maxheader:
  208. raise Exception('header exceeded allowable size')
  209. # indicates end of HTTP header
  210. if b'\r\n\r\n' in self.headerbuffer:
  211. self.request = HTTPRequest(self.headerbuffer)
  212. # handshake rfc 6455
  213. try:
  214. key = self.request.headers['Sec-WebSocket-Key']
  215. k = key.encode('ascii') + GUID_STR.encode('ascii')
  216. k_s = base64.b64encode(hashlib.sha1(k).digest()).decode('ascii')
  217. hStr = HANDSHAKE_STR % {'acceptstr': k_s}
  218. self.sendq.append((BINARY, hStr.encode('ascii')))
  219. self.handshaked = True
  220. self.handleConnected()
  221. except Exception as e:
  222. raise Exception('handshake failed: %s', str(e))
  223. # else do normal data
  224. else:
  225. data = self.client.recv(8192)
  226. if not data:
  227. raise Exception("remote socket closed")
  228. if VER >= 3:
  229. for d in data:
  230. self._parseMessage(d)
  231. else:
  232. for d in data:
  233. self._parseMessage(ord(d))
  234. def close(self, status = 1000, reason = u''):
  235. """
  236. Send Close frame to the client. The underlying socket is only closed
  237. when the client acknowledges the Close frame.
  238. status is the closing identifier.
  239. reason is the reason for the close.
  240. """
  241. try:
  242. if self.closed is False:
  243. close_msg = bytearray()
  244. close_msg.extend(struct.pack("!H", status))
  245. if _check_unicode(reason):
  246. close_msg.extend(reason.encode('utf-8'))
  247. else:
  248. close_msg.extend(reason)
  249. self._sendMessage(False, CLOSE, close_msg)
  250. finally:
  251. self.closed = True
  252. def _sendBuffer(self, buff):
  253. size = len(buff)
  254. tosend = size
  255. already_sent = 0
  256. while tosend > 0:
  257. try:
  258. # i should be able to send a bytearray
  259. sent = self.client.send(buff[already_sent:])
  260. if sent == 0:
  261. raise RuntimeError('socket connection broken')
  262. already_sent += sent
  263. tosend -= sent
  264. except socket.error as e:
  265. # if we have full buffers then wait for them to drain and try again
  266. if e.errno in [errno.EAGAIN, errno.EWOULDBLOCK]:
  267. return buff[already_sent:]
  268. else:
  269. raise e
  270. return None
  271. def sendFragmentStart(self, data):
  272. """
  273. Send the start of a data fragment stream to a websocket client.
  274. Subsequent data should be sent using sendFragment().
  275. A fragment stream is completed when sendFragmentEnd() is called.
  276. If data is a unicode object then the frame is sent as Text.
  277. If the data is a bytearray object then the frame is sent as Binary.
  278. """
  279. opcode = BINARY
  280. if _check_unicode(data):
  281. opcode = TEXT
  282. self._sendMessage(True, opcode, data)
  283. def sendFragment(self, data):
  284. """
  285. see sendFragmentStart()
  286. If data is a unicode object then the frame is sent as Text.
  287. If the data is a bytearray object then the frame is sent as Binary.
  288. """
  289. self._sendMessage(True, STREAM, data)
  290. def sendFragmentEnd(self, data):
  291. """
  292. see sendFragmentEnd()
  293. If data is a unicode object then the frame is sent as Text.
  294. If the data is a bytearray object then the frame is sent as Binary.
  295. """
  296. self._sendMessage(False, STREAM, data)
  297. def sendMessage(self, data):
  298. """
  299. Send websocket data frame to the client.
  300. If data is a unicode object then the frame is sent as Text.
  301. If the data is a bytearray object then the frame is sent as Binary.
  302. """
  303. opcode = BINARY
  304. if _check_unicode(data):
  305. opcode = TEXT
  306. self._sendMessage(False, opcode, data)
  307. def _sendMessage(self, fin, opcode, data):
  308. payload = bytearray()
  309. b1 = 0
  310. b2 = 0
  311. if fin is False:
  312. b1 |= 0x80
  313. b1 |= opcode
  314. if _check_unicode(data):
  315. data = data.encode('utf-8')
  316. length = len(data)
  317. payload.append(b1)
  318. if length <= 125:
  319. b2 |= length
  320. payload.append(b2)
  321. elif length >= 126 and length <= 65535:
  322. b2 |= 126
  323. payload.append(b2)
  324. payload.extend(struct.pack("!H", length))
  325. else:
  326. b2 |= 127
  327. payload.append(b2)
  328. payload.extend(struct.pack("!Q", length))
  329. if length > 0:
  330. payload.extend(data)
  331. self.sendq.append((opcode, payload))
  332. def _parseMessage(self, byte):
  333. # read in the header
  334. if self.state == HEADERB1:
  335. self.fin = byte & 0x80
  336. self.opcode = byte & 0x0F
  337. self.state = HEADERB2
  338. self.index = 0
  339. self.length = 0
  340. self.lengtharray = bytearray()
  341. self.data = bytearray()
  342. rsv = byte & 0x70
  343. if rsv != 0:
  344. raise Exception('RSV bit must be 0')
  345. elif self.state == HEADERB2:
  346. mask = byte & 0x80
  347. length = byte & 0x7F
  348. if self.opcode == PING and length > 125:
  349. raise Exception('ping packet is too large')
  350. if mask == 128:
  351. self.hasmask = True
  352. else:
  353. self.hasmask = False
  354. if length <= 125:
  355. self.length = length
  356. # if we have a mask we must read it
  357. if self.hasmask is True:
  358. self.maskarray = bytearray()
  359. self.state = MASK
  360. else:
  361. # if there is no mask and no payload we are done
  362. if self.length <= 0:
  363. try:
  364. self._handlePacket()
  365. finally:
  366. self.state = self.HEADERB1
  367. self.data = bytearray()
  368. # we have no mask and some payload
  369. else:
  370. #self.index = 0
  371. self.data = bytearray()
  372. self.state = PAYLOAD
  373. elif length == 126:
  374. self.lengtharray = bytearray()
  375. self.state = LENGTHSHORT
  376. elif length == 127:
  377. self.lengtharray = bytearray()
  378. self.state = LENGTHLONG
  379. elif self.state == LENGTHSHORT:
  380. self.lengtharray.append(byte)
  381. if len(self.lengtharray) > 2:
  382. raise Exception('short length exceeded allowable size')
  383. if len(self.lengtharray) == 2:
  384. self.length = struct.unpack_from('!H', self.lengtharray)[0]
  385. if self.hasmask is True:
  386. self.maskarray = bytearray()
  387. self.state = MASK
  388. else:
  389. # if there is no mask and no payload we are done
  390. if self.length <= 0:
  391. try:
  392. self._handlePacket()
  393. finally:
  394. self.state = HEADERB1
  395. self.data = bytearray()
  396. # we have no mask and some payload
  397. else:
  398. #self.index = 0
  399. self.data = bytearray()
  400. self.state = PAYLOAD
  401. elif self.state == LENGTHLONG:
  402. self.lengtharray.append(byte)
  403. if len(self.lengtharray) > 8:
  404. raise Exception('long length exceeded allowable size')
  405. if len(self.lengtharray) == 8:
  406. self.length = struct.unpack_from('!Q', self.lengtharray)[0]
  407. if self.hasmask is True:
  408. self.maskarray = bytearray()
  409. self.state = MASK
  410. else:
  411. # if there is no mask and no payload we are done
  412. if self.length <= 0:
  413. try:
  414. self._handlePacket()
  415. finally:
  416. self.state = HEADERB1
  417. self.data = bytearray()
  418. # we have no mask and some payload
  419. else:
  420. #self.index = 0
  421. self.data = bytearray()
  422. self.state = PAYLOAD
  423. # MASK STATE
  424. elif self.state == MASK:
  425. self.maskarray.append(byte)
  426. if len(self.maskarray) > 4:
  427. raise Exception('mask exceeded allowable size')
  428. if len(self.maskarray) == 4:
  429. # if there is no mask and no payload we are done
  430. if self.length <= 0:
  431. try:
  432. self._handlePacket()
  433. finally:
  434. self.state = HEADERB1
  435. self.data = bytearray()
  436. # we have no mask and some payload
  437. else:
  438. #self.index = 0
  439. self.data = bytearray()
  440. self.state = PAYLOAD
  441. # PAYLOAD STATE
  442. elif self.state == PAYLOAD:
  443. if self.hasmask is True:
  444. self.data.append( byte ^ self.maskarray[self.index % 4] )
  445. else:
  446. self.data.append( byte )
  447. # if length exceeds allowable size then we except and remove the connection
  448. if len(self.data) >= self.maxpayload:
  449. raise Exception('payload exceeded allowable size')
  450. # check if we have processed length bytes; if so we are done
  451. if (self.index+1) == self.length:
  452. try:
  453. self._handlePacket()
  454. finally:
  455. #self.index = 0
  456. self.state = HEADERB1
  457. self.data = bytearray()
  458. else:
  459. self.index += 1
  460. class SimpleWebSocketServer(object):
  461. def __init__(self, host, port, websocketclass, selectInterval = 0.1):
  462. self.websocketclass = websocketclass
  463. self.serversocket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
  464. self.serversocket.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
  465. self.serversocket.bind((host, port))
  466. self.serversocket.listen(5)
  467. self.selectInterval = selectInterval
  468. self.connections = {}
  469. self.listeners = [self.serversocket]
  470. def _decorateSocket(self, sock):
  471. return sock
  472. def _constructWebSocket(self, sock, address):
  473. return self.websocketclass(self, sock, address)
  474. def close(self):
  475. self.serversocket.close()
  476. for desc, conn in self.connections.items():
  477. conn.close()
  478. conn.handleClose()
  479. def serveforever(self):
  480. while True:
  481. writers = []
  482. for fileno in self.listeners:
  483. if fileno == self.serversocket:
  484. continue
  485. client = self.connections[fileno]
  486. if client.sendq:
  487. writers.append(fileno)
  488. if self.selectInterval:
  489. rList, wList, xList = select(self.listeners, writers, self.listeners, self.selectInterval)
  490. else:
  491. rList, wList, xList = select(self.listeners, writers, self.listeners)
  492. for ready in wList:
  493. client = self.connections[ready]
  494. try:
  495. while client.sendq:
  496. opcode, payload = client.sendq.popleft()
  497. remaining = client._sendBuffer(payload)
  498. if remaining is not None:
  499. client.sendq.appendleft((opcode, remaining))
  500. break
  501. else:
  502. if opcode == CLOSE:
  503. raise Exception('received client close')
  504. except Exception as n:
  505. client.client.close()
  506. client.handleClose()
  507. del self.connections[ready]
  508. self.listeners.remove(ready)
  509. for ready in rList:
  510. if ready == self.serversocket:
  511. try:
  512. sock, address = self.serversocket.accept()
  513. newsock = self._decorateSocket(sock)
  514. newsock.setblocking(0)
  515. fileno = newsock.fileno()
  516. self.connections[fileno] = self._constructWebSocket(newsock, address)
  517. self.listeners.append(fileno)
  518. except Exception as n:
  519. if sock is not None:
  520. sock.close()
  521. else:
  522. if ready not in self.connections:
  523. continue
  524. client = self.connections[ready]
  525. try:
  526. client._handleData()
  527. except Exception as n:
  528. client.client.close()
  529. client.handleClose()
  530. del self.connections[ready]
  531. self.listeners.remove(ready)
  532. for failed in xList:
  533. if failed == self.serversocket:
  534. self.close()
  535. raise Exception('server socket failed')
  536. else:
  537. if failed not in self.connections:
  538. continue
  539. client = self.connections[failed]
  540. client.client.close()
  541. client.handleClose()
  542. del self.connections[failed]
  543. self.listeners.remove(failed)
  544. class SimpleSSLWebSocketServer(SimpleWebSocketServer):
  545. def __init__(self, host, port, websocketclass, certfile,
  546. keyfile, version = ssl.PROTOCOL_TLSv1, selectInterval = 0.1):
  547. SimpleWebSocketServer.__init__(self, host, port,
  548. websocketclass, selectInterval)
  549. self.context = ssl.SSLContext(version)
  550. self.context.load_cert_chain(certfile, keyfile)
  551. def close(self):
  552. super(SimpleSSLWebSocketServer, self).close()
  553. def _decorateSocket(self, sock):
  554. sslsock = self.context.wrap_socket(sock, server_side=True)
  555. return sslsock
  556. def _constructWebSocket(self, sock, address):
  557. ws = self.websocketclass(self, sock, address)
  558. ws.usingssl = True
  559. return ws
  560. def serveforever(self):
  561. super(SimpleSSLWebSocketServer, self).serveforever()