test_websocket.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212
  1. # vim: tabstop=4 shiftwidth=4 softtabstop=4
  2. # Copyright(c)2013 NTT corp. All Rights Reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License"); you may
  5. # not use this file except in compliance with the License. You may obtain
  6. # a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
  12. # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
  13. # License for the specific language governing permissions and limitations
  14. # under the License.
  15. """ Unit tests for websocket """
  16. import unittest
  17. from websockify import websocket
  18. class FakeSocket:
  19. def __init__(self):
  20. self.data = b''
  21. def send(self, buf):
  22. self.data += buf
  23. return len(buf)
  24. class AcceptTestCase(unittest.TestCase):
  25. def test_success(self):
  26. ws = websocket.WebSocket()
  27. sock = FakeSocket()
  28. ws.accept(sock, {'upgrade': 'websocket',
  29. 'Sec-WebSocket-Version': '13',
  30. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  31. self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
  32. self.assertTrue(b'\r\nUpgrade: websocket\r\n' in sock.data)
  33. self.assertTrue(b'\r\nConnection: Upgrade\r\n' in sock.data)
  34. self.assertTrue(b'\r\nSec-WebSocket-Accept: pczpYSQsvE1vBpTQYjFQPcuoj6M=\r\n' in sock.data)
  35. def test_bad_version(self):
  36. ws = websocket.WebSocket()
  37. sock = FakeSocket()
  38. self.assertRaises(Exception, ws.accept,
  39. sock, {'upgrade': 'websocket',
  40. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  41. self.assertRaises(Exception, ws.accept,
  42. sock, {'upgrade': 'websocket',
  43. 'Sec-WebSocket-Version': '5',
  44. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  45. self.assertRaises(Exception, ws.accept,
  46. sock, {'upgrade': 'websocket',
  47. 'Sec-WebSocket-Version': '20',
  48. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  49. def test_bad_upgrade(self):
  50. ws = websocket.WebSocket()
  51. sock = FakeSocket()
  52. self.assertRaises(Exception, ws.accept,
  53. sock, {'Sec-WebSocket-Version': '13',
  54. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  55. self.assertRaises(Exception, ws.accept,
  56. sock, {'upgrade': 'websocket2',
  57. 'Sec-WebSocket-Version': '13',
  58. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  59. def test_missing_key(self):
  60. ws = websocket.WebSocket()
  61. sock = FakeSocket()
  62. self.assertRaises(Exception, ws.accept,
  63. sock, {'upgrade': 'websocket',
  64. 'Sec-WebSocket-Version': '13'})
  65. def test_protocol(self):
  66. class ProtoSocket(websocket.WebSocket):
  67. def select_subprotocol(self, protocol):
  68. return 'gazonk'
  69. ws = ProtoSocket()
  70. sock = FakeSocket()
  71. ws.accept(sock, {'upgrade': 'websocket',
  72. 'Sec-WebSocket-Version': '13',
  73. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
  74. 'Sec-WebSocket-Protocol': 'foobar gazonk'})
  75. self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
  76. self.assertTrue(b'\r\nSec-WebSocket-Protocol: gazonk\r\n' in sock.data)
  77. def test_no_protocol(self):
  78. ws = websocket.WebSocket()
  79. sock = FakeSocket()
  80. ws.accept(sock, {'upgrade': 'websocket',
  81. 'Sec-WebSocket-Version': '13',
  82. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  83. self.assertEqual(sock.data[:13], b'HTTP/1.1 101 ')
  84. self.assertFalse(b'\r\nSec-WebSocket-Protocol:' in sock.data)
  85. def test_missing_protocol(self):
  86. ws = websocket.WebSocket()
  87. sock = FakeSocket()
  88. self.assertRaises(Exception, ws.accept,
  89. sock, {'upgrade': 'websocket',
  90. 'Sec-WebSocket-Version': '13',
  91. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
  92. 'Sec-WebSocket-Protocol': 'foobar gazonk'})
  93. def test_protocol(self):
  94. class ProtoSocket(websocket.WebSocket):
  95. def select_subprotocol(self, protocol):
  96. return 'oddball'
  97. ws = ProtoSocket()
  98. sock = FakeSocket()
  99. self.assertRaises(Exception, ws.accept,
  100. sock, {'upgrade': 'websocket',
  101. 'Sec-WebSocket-Version': '13',
  102. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q==',
  103. 'Sec-WebSocket-Protocol': 'foobar gazonk'})
  104. class PingPongTest(unittest.TestCase):
  105. def setUp(self):
  106. self.ws = websocket.WebSocket()
  107. self.sock = FakeSocket()
  108. self.ws.accept(self.sock, {'upgrade': 'websocket',
  109. 'Sec-WebSocket-Version': '13',
  110. 'Sec-WebSocket-Key': 'DKURYVK9cRFul1vOZVA56Q=='})
  111. self.assertEqual(self.sock.data[:13], b'HTTP/1.1 101 ')
  112. self.sock.data = b''
  113. def test_ping(self):
  114. self.ws.ping()
  115. self.assertEqual(self.sock.data, b'\x89\x00')
  116. def test_pong(self):
  117. self.ws.pong()
  118. self.assertEqual(self.sock.data, b'\x8a\x00')
  119. def test_ping_data(self):
  120. self.ws.ping(b'foo')
  121. self.assertEqual(self.sock.data, b'\x89\x03foo')
  122. def test_pong_data(self):
  123. self.ws.pong(b'foo')
  124. self.assertEqual(self.sock.data, b'\x8a\x03foo')
  125. class HyBiEncodeDecodeTestCase(unittest.TestCase):
  126. def test_decode_hybi_text(self):
  127. buf = b'\x81\x85\x37\xfa\x21\x3d\x7f\x9f\x4d\x51\x58'
  128. ws = websocket.WebSocket()
  129. res = ws._decode_hybi(buf)
  130. self.assertEqual(res['fin'], 1)
  131. self.assertEqual(res['opcode'], 0x1)
  132. self.assertEqual(res['masked'], True)
  133. self.assertEqual(res['length'], len(buf))
  134. self.assertEqual(res['payload'], b'Hello')
  135. def test_decode_hybi_binary(self):
  136. buf = b'\x82\x04\x01\x02\x03\x04'
  137. ws = websocket.WebSocket()
  138. res = ws._decode_hybi(buf)
  139. self.assertEqual(res['fin'], 1)
  140. self.assertEqual(res['opcode'], 0x2)
  141. self.assertEqual(res['length'], len(buf))
  142. self.assertEqual(res['payload'], b'\x01\x02\x03\x04')
  143. def test_decode_hybi_extended_16bit_binary(self):
  144. data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
  145. buf = b'\x82\x7e\x01\x04' + data
  146. ws = websocket.WebSocket()
  147. res = ws._decode_hybi(buf)
  148. self.assertEqual(res['fin'], 1)
  149. self.assertEqual(res['opcode'], 0x2)
  150. self.assertEqual(res['length'], len(buf))
  151. self.assertEqual(res['payload'], data)
  152. def test_decode_hybi_extended_64bit_binary(self):
  153. data = (b'\x01\x02\x03\x04' * 65) # len > 126 -- len == 260
  154. buf = b'\x82\x7f\x00\x00\x00\x00\x00\x00\x01\x04' + data
  155. ws = websocket.WebSocket()
  156. res = ws._decode_hybi(buf)
  157. self.assertEqual(res['fin'], 1)
  158. self.assertEqual(res['opcode'], 0x2)
  159. self.assertEqual(res['length'], len(buf))
  160. self.assertEqual(res['payload'], data)
  161. def test_decode_hybi_multi(self):
  162. buf1 = b'\x01\x03\x48\x65\x6c'
  163. buf2 = b'\x80\x02\x6c\x6f'
  164. ws = websocket.WebSocket()
  165. res1 = ws._decode_hybi(buf1)
  166. self.assertEqual(res1['fin'], 0)
  167. self.assertEqual(res1['opcode'], 0x1)
  168. self.assertEqual(res1['length'], len(buf1))
  169. self.assertEqual(res1['payload'], b'Hel')
  170. res2 = ws._decode_hybi(buf2)
  171. self.assertEqual(res2['fin'], 1)
  172. self.assertEqual(res2['opcode'], 0x0)
  173. self.assertEqual(res2['length'], len(buf2))
  174. self.assertEqual(res2['payload'], b'lo')
  175. def test_encode_hybi_basic(self):
  176. ws = websocket.WebSocket()
  177. res = ws._encode_hybi(0x1, b'Hello')
  178. expected = b'\x81\x05\x48\x65\x6c\x6c\x6f'
  179. self.assertEqual(res, expected)