| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400 |
- # vim: tabstop=4 shiftwidth=4 softtabstop=4
- # Copyright(c)2013 NTT corp. All Rights Reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License"); you may
- # not use this file except in compliance with the License. You may obtain
- # a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- # License for the specific language governing permissions and limitations
- # under the License.
- """ Unit tests for websockifyserver """
- import errno
- import os
- import logging
- import select
- import shutil
- import socket
- import ssl
- from unittest.mock import patch, MagicMock, ANY
- import sys
- import tempfile
- import unittest
- import socket
- import signal
- from http.server import BaseHTTPRequestHandler
- from io import StringIO
- from io import BytesIO
- from websockify import websockifyserver
def raise_oserror(*args, **kwargs):
    """Mock side effect that always fails with ``OSError('fake error')``.

    Every positional and keyword argument is accepted and ignored, so this
    can stand in for any patched callable.
    """
    failure = OSError('fake error')
    raise failure
class FakeSocket(object):
    """Minimal in-memory stand-in for a socket, backed by a bytes buffer.

    Only ``recv`` (optionally peeking) and ``makefile`` are provided.
    """

    def __init__(self, data=b''):
        # Bytes not yet consumed by a non-peeking recv().
        self._data = data

    def recv(self, amt, flags=None):
        """Return up to *amt* leading bytes of the buffer.

        The returned bytes are consumed unless *flags* includes
        ``socket.MSG_PEEK``. Omitting *flags* (``None``) behaves like 0,
        i.e. a plain consuming read.
        """
        res = self._data[0:amt]
        # ``None & MSG_PEEK`` raises TypeError, so normalise a missing
        # flags argument to 0 before testing the bit.
        if not ((flags or 0) & socket.MSG_PEEK):
            self._data = self._data[amt:]
        return res

    def makefile(self, mode='r', buffsize=None):
        """Return a file object over the bytes currently in the buffer.

        Binary modes (``'b'`` in *mode*) get a ``BytesIO``; text modes get a
        ``StringIO`` decoded as Latin-1, which maps each byte to exactly one
        character. Reading the returned object does not consume the buffer.
        *buffsize* is accepted for signature compatibility and ignored.
        """
        if 'b' in mode:
            return BytesIO(self._data)
        return StringIO(self._data.decode('latin_1'))
class WebSockifyRequestHandlerTestCase(unittest.TestCase):
    """Tests for WebSockifyRequestHandler's handling of plain GET requests."""

    def setUp(self):
        super().setUp()
        # Scratch directory used as the server's key, web and record paths.
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Patch out os.chdir so the code under test cannot change the test
        # process's working directory (it otherwise interferes with tests).
        patch('os.chdir').start()

    def tearDown(self):
        """Called automatically after each test."""
        patch.stopall()
        os.rmdir(self.tmpdir)
        super().tearDown()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a non-daemon WebSockifyServer for handler tests.

        ``web`` defaults to the scratch directory but may be overridden
        (e.g. ``web=None``); every other keyword is forwarded as-is.
        """
        web_root = kwargs.pop('web', self.tmpdir)
        return websockifyserver.WebSockifyServer(
            handler_class,
            listen_host='localhost',
            listen_port=80,
            key=self.tmpdir,
            web=web_root,
            record=self.tmpdir,
            daemon=False,
            ssl_only=0,
            idle_timeout=1,
            **kwargs)

    @patch.object(websockifyserver.WebSockifyRequestHandler, 'send_error')
    def test_normal_get_with_only_upgrade_returns_error(self, send_error):
        # With web=None, a plain (non-upgrade) GET must be answered with 405.
        server = self._get_server(web=None)
        request = FakeSocket(b'GET /tmp.txt HTTP/1.1')
        handler = websockifyserver.WebSockifyRequestHandler(
            request, '127.0.0.1', server)
        handler.do_GET()
        send_error.assert_called_with(405, ANY)

    @patch.object(websockifyserver.WebSockifyRequestHandler, 'send_error')
    def test_list_dir_with_file_only_returns_error(self, send_error):
        # With file_only=True, requesting the directory "/" must yield 404.
        server = self._get_server(file_only=True)
        request = FakeSocket(b'GET / HTTP/1.1')
        handler = websockifyserver.WebSockifyRequestHandler(
            request, '127.0.0.1', server)
        handler.path = '/'
        handler.do_GET()
        send_error.assert_called_with(404, ANY)
class WebSockifyServerTestCase(unittest.TestCase):
    """Tests for WebSockifyServer daemonizing, handshaking and socket setup."""

    def setUp(self):
        super(WebSockifyServerTestCase, self).setUp()
        # Scratch directory used as the server's key, web and record paths.
        self.tmpdir = tempfile.mkdtemp('-websockify-tests')
        # Patch out os.chdir so the code under test cannot change the test
        # process's working directory (it otherwise interferes with tests).
        patch('os.chdir').start()

    def tearDown(self):
        """Called automatically after each test."""
        patch.stopall()
        os.rmdir(self.tmpdir)
        super(WebSockifyServerTestCase, self).tearDown()

    def _get_server(self, handler_class=websockifyserver.WebSockifyRequestHandler,
                    **kwargs):
        """Build a WebSockifyServer on localhost:80 rooted at the scratch dir.

        Options such as ``daemon``, ``ssl_only`` and ``idle_timeout`` are not
        defaulted here; each test passes the ones it needs via *kwargs*.
        """
        return websockifyserver.WebSockifyServer(
            handler_class, listen_host='localhost',
            listen_port=80, key=self.tmpdir, web=self.tmpdir,
            record=self.tmpdir, **kwargs)

    def _patch_select(self, *readable):
        """Patch ``select.select`` to report *readable* as ready to read.

        Called with no arguments, every select() reports nothing ready. The
        patch is undone by ``patch.stopall()`` in tearDown.
        """
        def fake_select(rlist, wlist, xlist, timeout=None):
            return (list(readable), [], [])

        patch('select.select').start().side_effect = fake_select

    def test_daemonize_raises_error_while_closing_fds(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        patch('os.fork').start().return_value = 0
        patch('signal.signal').start()
        patch('os.setsid').start()
        patch('os.close').start().side_effect = raise_oserror
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_daemonize_ignores_ebadf_error_while_closing_fds(self):
        def raise_oserror_ebadf(fd):
            raise OSError(errno.EBADF, 'fake error')

        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        patch('os.fork').start().return_value = 0
        patch('signal.signal').start()
        patch('os.setsid').start()
        patch('os.close').start().side_effect = raise_oserror_ebadf
        # EBADF from os.close must be tolerated; the OSError asserted below
        # comes from the patched os.open instead.
        patch('os.open').start().side_effect = raise_oserror
        self.assertRaises(OSError, server.daemonize, keepfd=None, chdir='./')

    def test_handshake_fails_on_not_ready(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
        self._patch_select()
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            FakeSocket(), '127.0.0.1')

    def test_empty_handshake_fails(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
        # Bytes, not str: a real socket's recv() returns bytes.
        sock = FakeSocket(b'')
        self._patch_select(sock)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_handshake_policy_request(self):
        # TODO(directxman12): implement
        pass

    def test_handshake_ssl_only_without_ssl_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=1, idle_timeout=1)
        sock = FakeSocket(b'some initial data')
        self._patch_select(sock)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_no_ssl(self):
        class FakeHandler(object):
            CALLED = False

            def __init__(self, *args, **kwargs):
                type(self).CALLED = True

        FakeHandler.CALLED = False
        server = self._get_server(
            handler_class=FakeHandler, daemon=True,
            ssl_only=0, idle_timeout=1)
        sock = FakeSocket(b'some initial data')
        self._patch_select(sock)
        self.assertEqual(server.do_handshake(sock, '127.0.0.1'), sock)
        # assertTrue's second argument is the failure message, not an
        # expected value, so only the condition is passed.
        self.assertTrue(FakeHandler.CALLED)

    def test_do_handshake_ssl(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_ssl_raises_error(self):
        # TODO(directxman12): implement this
        pass

    def test_do_handshake_ssl_without_cert_raises_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1,
                                  cert='afdsfasdafdsafdsafdsafdas')
        # A leading 0x16 byte makes the data look like a TLS handshake.
        sock = FakeSocket(b"\x16some ssl data")
        self._patch_select(sock)
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_error_eof_raises_close_error(self):
        server = self._get_server(daemon=True, ssl_only=0, idle_timeout=1)
        sock = FakeSocket(b"\x16some ssl data")

        class fake_create_default_context():
            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                # Simulate the peer closing the connection mid-handshake.
                raise ssl.SSLError(ssl.SSL_ERROR_EOF)

        self._patch_select(sock)
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        self.assertRaises(
            websockifyserver.WebSockifyServer.EClose, server.do_handshake,
            sock, '127.0.0.1')

    def test_do_handshake_ssl_sets_ciphers(self):
        test_ciphers = 'TEST-CIPHERS-1:TEST-CIPHER-2'

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_ciphers=test_ciphers)
        sock = FakeSocket(b"\x16some ssl data")

        class fake_create_default_context():
            # Records the last value passed to set_ciphers().
            CIPHERS = ''

            def __init__(self, purpose):
                self.verify_mode = None
                self.options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def set_ciphers(self, ciphers_to_set):
                fake_create_default_context.CIPHERS = ciphers_to_set

        self._patch_select(sock)
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        server.do_handshake(sock, '127.0.0.1')
        self.assertEqual(fake_create_default_context.CIPHERS, test_ciphers)

    def test_do_handshake_ssl_sets_opions(self):
        test_options = 0xCAFEBEEF

        class FakeHandler(object):
            def __init__(self, *args, **kwargs):
                pass

        server = self._get_server(handler_class=FakeHandler, daemon=True,
                                  idle_timeout=1, ssl_options=test_options)
        sock = FakeSocket(b"\x16some ssl data")

        class fake_create_default_context(object):
            # Records the last value assigned to the ``options`` property.
            OPTIONS = 0

            def __init__(self, purpose):
                self.verify_mode = None
                self._options = 0

            def load_cert_chain(self, certfile, keyfile, password):
                pass

            def set_default_verify_paths(self):
                pass

            def load_verify_locations(self, cafile):
                pass

            def wrap_socket(self, *args, **kwargs):
                pass

            def get_options(self):
                return self._options

            def set_options(self, val):
                fake_create_default_context.OPTIONS = val

            options = property(get_options, set_options)

        self._patch_select(sock)
        patch('ssl.create_default_context').start().side_effect = fake_create_default_context
        server.do_handshake(sock, '127.0.0.1')
        self.assertEqual(fake_create_default_context.OPTIONS, test_options)

    def test_fallback_sigchld_handler(self):
        # TODO(directxman12): implement this
        pass

    # The start_server tests patch WebSockifyServer.socket, so they need no
    # real listening socket (creating one here would only leak it).

    def test_start_server_error(self):
        server = self._get_server(daemon=False, ssl_only=1, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise Exception("fake error")

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_start_server_keyboardinterrupt(self):
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            raise KeyboardInterrupt

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_start_server_systemexit(self):
        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)

        def fake_select(rlist, wlist, xlist, timeout=None):
            sys.exit()

        patch('websockify.websockifyserver.WebSockifyServer.socket').start()
        patch('websockify.websockifyserver.WebSockifyServer.daemonize').start()
        patch('select.select').start().side_effect = fake_select
        server.start_server()

    def test_socket_set_keepalive_options(self):
        keepcnt = 12
        keepidle = 34
        keepintvl = 56

        server = self._get_server(daemon=False, ssl_only=0, idle_timeout=1)
        sock = server.socket('localhost',
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        # Close the real socket even if an assertion below fails.
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPCNT), keepcnt)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPIDLE), keepidle)
            self.assertEqual(sock.getsockopt(socket.SOL_TCP,
                                             socket.TCP_KEEPINTVL), keepintvl)

        # With keepalive disabled the tuning values must not be applied.
        sock = server.socket('localhost',
                             tcp_keepalive=False,
                             tcp_keepcnt=keepcnt,
                             tcp_keepidle=keepidle,
                             tcp_keepintvl=keepintvl)
        self.addCleanup(sock.close)

        if hasattr(socket, 'TCP_KEEPCNT'):
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPCNT), keepcnt)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPIDLE), keepidle)
            self.assertNotEqual(sock.getsockopt(socket.SOL_TCP,
                                                socket.TCP_KEEPINTVL), keepintvl)
|