891 lines
36 KiB
Python
891 lines
36 KiB
Python
import gzip
|
|
import importlib
|
|
import json
|
|
import logging
|
|
import sys
|
|
import time
|
|
import unittest
|
|
import zlib
|
|
|
|
import six
|
|
if six.PY3:
|
|
from unittest import mock
|
|
else:
|
|
import mock
|
|
|
|
from engineio import exceptions
|
|
from engineio import packet
|
|
from engineio import payload
|
|
from engineio import server
|
|
|
|
|
|
original_import_module = importlib.import_module
|
|
|
|
|
|
def _mock_import(module, *args, **kwargs):
|
|
if module.startswith('engineio.'):
|
|
return original_import_module(module, *args, **kwargs)
|
|
return module
|
|
|
|
|
|
class TestServer(unittest.TestCase):
|
|
_mock_async = mock.MagicMock()
|
|
_mock_async._async = {
|
|
'threading': 't',
|
|
'thread_class': 'tc',
|
|
'queue': 'q',
|
|
'queue_class': 'qc',
|
|
'websocket': 'w',
|
|
'websocket_class': 'wc'
|
|
}
|
|
|
|
def _get_mock_socket(self):
|
|
mock_socket = mock.MagicMock()
|
|
mock_socket.closed = False
|
|
mock_socket.closing = False
|
|
mock_socket.upgraded = False
|
|
return mock_socket
|
|
|
|
@classmethod
|
|
def setUpClass(cls):
|
|
server.Server._default_monitor_clients = False
|
|
|
|
@classmethod
|
|
def tearDownClass(cls):
|
|
server.Server._default_monitor_clients = True
|
|
|
|
def setUp(self):
|
|
logging.getLogger('engineio').setLevel(logging.NOTSET)
|
|
|
|
def tearDown(self):
|
|
# restore JSON encoder, in case a test changed it
|
|
packet.Packet.json = json
|
|
|
|
def test_is_asyncio_based(self):
|
|
s = server.Server()
|
|
self.assertEqual(s.is_asyncio_based(), False)
|
|
|
|
def test_async_modes(self):
|
|
s = server.Server()
|
|
self.assertEqual(s.async_modes(), ['eventlet', 'gevent_uwsgi',
|
|
'gevent', 'threading'])
|
|
|
|
def test_create(self):
|
|
kwargs = {
|
|
'ping_timeout': 1,
|
|
'ping_interval': 2,
|
|
'max_http_buffer_size': 3,
|
|
'allow_upgrades': False,
|
|
'http_compression': False,
|
|
'compression_threshold': 4,
|
|
'cookie': 'foo',
|
|
'cors_allowed_origins': ['foo', 'bar', 'baz'],
|
|
'cors_credentials': False,
|
|
'async_handlers': False}
|
|
s = server.Server(**kwargs)
|
|
for arg in six.iterkeys(kwargs):
|
|
self.assertEqual(getattr(s, arg), kwargs[arg])
|
|
|
|
def test_create_ignores_kwargs(self):
|
|
server.Server(foo='bar') # this should not raise
|
|
|
|
def test_async_mode_threading(self):
|
|
s = server.Server(async_mode='threading')
|
|
self.assertEqual(s.async_mode, 'threading')
|
|
|
|
import threading
|
|
try:
|
|
import queue
|
|
except ImportError:
|
|
import Queue as queue
|
|
|
|
self.assertEqual(s._async['threading'], threading)
|
|
self.assertEqual(s._async['thread_class'], 'Thread')
|
|
self.assertEqual(s._async['queue'], queue)
|
|
self.assertEqual(s._async['queue_class'], 'Queue')
|
|
self.assertEqual(s._async['websocket'], None)
|
|
self.assertEqual(s._async['websocket_class'], None)
|
|
|
|
def test_async_mode_eventlet(self):
|
|
s = server.Server(async_mode='eventlet')
|
|
self.assertEqual(s.async_mode, 'eventlet')
|
|
|
|
from eventlet.green import threading
|
|
from eventlet import queue
|
|
from engineio.async_drivers import eventlet as async_eventlet
|
|
|
|
self.assertEqual(s._async['threading'], threading)
|
|
self.assertEqual(s._async['thread_class'], 'Thread')
|
|
self.assertEqual(s._async['queue'], queue)
|
|
self.assertEqual(s._async['queue_class'], 'Queue')
|
|
self.assertEqual(s._async['websocket'], async_eventlet)
|
|
self.assertEqual(s._async['websocket_class'], 'WebSocketWSGI')
|
|
|
|
@mock.patch('importlib.import_module', side_effect=_mock_import)
|
|
def test_async_mode_gevent_uwsgi(self, import_module):
|
|
sys.modules['gevent'] = mock.MagicMock()
|
|
sys.modules['uwsgi'] = mock.MagicMock()
|
|
s = server.Server(async_mode='gevent_uwsgi')
|
|
self.assertEqual(s.async_mode, 'gevent_uwsgi')
|
|
|
|
from engineio.async_drivers import gevent_uwsgi as async_gevent_uwsgi
|
|
|
|
self.assertEqual(s._async['threading'], async_gevent_uwsgi)
|
|
self.assertEqual(s._async['thread_class'], 'Thread')
|
|
self.assertEqual(s._async['queue'], 'gevent.queue')
|
|
self.assertEqual(s._async['queue_class'], 'JoinableQueue')
|
|
self.assertEqual(s._async['websocket'], async_gevent_uwsgi)
|
|
self.assertEqual(s._async['websocket_class'], 'uWSGIWebSocket')
|
|
del sys.modules['gevent']
|
|
del sys.modules['uwsgi']
|
|
del sys.modules['engineio.async_drivers.gevent_uwsgi']
|
|
|
|
@mock.patch('importlib.import_module', side_effect=_mock_import)
|
|
def test_async_mode_gevent_uwsgi_without_uwsgi(self, import_module):
|
|
sys.modules['gevent'] = mock.MagicMock()
|
|
sys.modules['uwsgi'] = None
|
|
self.assertRaises(ValueError, server.Server,
|
|
async_mode='gevent_uwsgi')
|
|
del sys.modules['gevent']
|
|
del sys.modules['uwsgi']
|
|
|
|
@mock.patch('importlib.import_module', side_effect=_mock_import)
|
|
def test_async_mode_gevent_uwsgi_without_websocket(self, import_module):
|
|
sys.modules['gevent'] = mock.MagicMock()
|
|
sys.modules['uwsgi'] = mock.MagicMock()
|
|
del sys.modules['uwsgi'].websocket_handshake
|
|
s = server.Server(async_mode='gevent_uwsgi')
|
|
self.assertEqual(s.async_mode, 'gevent_uwsgi')
|
|
|
|
from engineio.async_drivers import gevent_uwsgi as async_gevent_uwsgi
|
|
|
|
self.assertEqual(s._async['threading'], async_gevent_uwsgi)
|
|
self.assertEqual(s._async['thread_class'], 'Thread')
|
|
self.assertEqual(s._async['queue'], 'gevent.queue')
|
|
self.assertEqual(s._async['queue_class'], 'JoinableQueue')
|
|
self.assertEqual(s._async['websocket'], None)
|
|
self.assertEqual(s._async['websocket_class'], None)
|
|
del sys.modules['gevent']
|
|
del sys.modules['uwsgi']
|
|
del sys.modules['engineio.async_drivers.gevent_uwsgi']
|
|
|
|
@mock.patch('importlib.import_module', side_effect=_mock_import)
|
|
def test_async_mode_gevent(self, import_module):
|
|
sys.modules['gevent'] = mock.MagicMock()
|
|
sys.modules['geventwebsocket'] = 'geventwebsocket'
|
|
s = server.Server(async_mode='gevent')
|
|
self.assertEqual(s.async_mode, 'gevent')
|
|
|
|
from engineio.async_drivers import gevent as async_gevent
|
|
|
|
self.assertEqual(s._async['threading'], async_gevent)
|
|
self.assertEqual(s._async['thread_class'], 'Thread')
|
|
self.assertEqual(s._async['queue'], 'gevent.queue')
|
|
self.assertEqual(s._async['queue_class'], 'JoinableQueue')
|
|
self.assertEqual(s._async['websocket'], async_gevent)
|
|
self.assertEqual(s._async['websocket_class'], 'WebSocketWSGI')
|
|
del sys.modules['gevent']
|
|
del sys.modules['geventwebsocket']
|
|
del sys.modules['engineio.async_drivers.gevent']
|
|
|
|
@mock.patch('importlib.import_module', side_effect=_mock_import)
|
|
def test_async_mode_gevent_without_websocket(self, import_module):
|
|
sys.modules['gevent'] = mock.MagicMock()
|
|
sys.modules['geventwebsocket'] = None
|
|
s = server.Server(async_mode='gevent')
|
|
self.assertEqual(s.async_mode, 'gevent')
|
|
|
|
from engineio.async_drivers import gevent as async_gevent
|
|
|
|
self.assertEqual(s._async['threading'], async_gevent)
|
|
self.assertEqual(s._async['thread_class'], 'Thread')
|
|
self.assertEqual(s._async['queue'], 'gevent.queue')
|
|
self.assertEqual(s._async['queue_class'], 'JoinableQueue')
|
|
self.assertEqual(s._async['websocket'], None)
|
|
self.assertEqual(s._async['websocket_class'], None)
|
|
del sys.modules['gevent']
|
|
del sys.modules['geventwebsocket']
|
|
del sys.modules['engineio.async_drivers.gevent']
|
|
|
|
@unittest.skipIf(sys.version_info < (3, 5), 'only for Python 3.5+')
|
|
@mock.patch('importlib.import_module', side_effect=_mock_import)
|
|
def test_async_mode_aiohttp(self, import_module):
|
|
sys.modules['aiohttp'] = mock.MagicMock()
|
|
self.assertRaises(ValueError, server.Server, async_mode='aiohttp')
|
|
|
|
@mock.patch('importlib.import_module', side_effect=[ImportError])
|
|
def test_async_mode_invalid(self, import_module):
|
|
self.assertRaises(ValueError, server.Server, async_mode='foo')
|
|
|
|
@mock.patch('importlib.import_module', side_effect=[_mock_async])
|
|
def test_async_mode_auto_eventlet(self, import_module):
|
|
s = server.Server()
|
|
self.assertEqual(s.async_mode, 'eventlet')
|
|
|
|
@mock.patch('importlib.import_module', side_effect=[ImportError,
|
|
_mock_async])
|
|
def test_async_mode_auto_gevent_uwsgi(self, import_module):
|
|
s = server.Server()
|
|
self.assertEqual(s.async_mode, 'gevent_uwsgi')
|
|
|
|
@mock.patch('importlib.import_module', side_effect=[ImportError,
|
|
ImportError,
|
|
_mock_async])
|
|
def test_async_mode_auto_gevent(self, import_module):
|
|
s = server.Server()
|
|
self.assertEqual(s.async_mode, 'gevent')
|
|
|
|
@mock.patch('importlib.import_module', side_effect=[ImportError,
|
|
ImportError,
|
|
ImportError,
|
|
_mock_async])
|
|
def test_async_mode_auto_threading(self, import_module):
|
|
s = server.Server()
|
|
self.assertEqual(s.async_mode, 'threading')
|
|
|
|
def test_generate_id(self):
|
|
s = server.Server()
|
|
self.assertNotEqual(s._generate_id(), s._generate_id())
|
|
|
|
def test_on_event(self):
|
|
s = server.Server()
|
|
|
|
@s.on('connect')
|
|
def foo():
|
|
pass
|
|
s.on('disconnect', foo)
|
|
|
|
self.assertEqual(s.handlers['connect'], foo)
|
|
self.assertEqual(s.handlers['disconnect'], foo)
|
|
|
|
def test_on_event_invalid(self):
|
|
s = server.Server()
|
|
self.assertRaises(ValueError, s.on, 'invalid')
|
|
|
|
def test_trigger_event(self):
|
|
s = server.Server()
|
|
f = {}
|
|
|
|
@s.on('connect')
|
|
def foo(sid, environ):
|
|
return sid + environ
|
|
|
|
@s.on('message')
|
|
def bar(sid, data):
|
|
f['bar'] = sid + data
|
|
return 'bar'
|
|
|
|
r = s._trigger_event('connect', 1, 2, run_async=False)
|
|
self.assertEqual(r, 3)
|
|
r = s._trigger_event('message', 3, 4, run_async=True)
|
|
r.join()
|
|
self.assertEqual(f['bar'], 7)
|
|
r = s._trigger_event('message', 5, 6)
|
|
self.assertEqual(r, 'bar')
|
|
|
|
def test_trigger_event_error(self):
|
|
s = server.Server()
|
|
|
|
@s.on('connect')
|
|
def foo(sid, environ):
|
|
return 1 / 0
|
|
|
|
@s.on('message')
|
|
def bar(sid, data):
|
|
return 1 / 0
|
|
|
|
r = s._trigger_event('connect', 1, 2, run_async=False)
|
|
self.assertEqual(r, False)
|
|
r = s._trigger_event('message', 3, 4, run_async=False)
|
|
self.assertEqual(r, None)
|
|
|
|
def test_close_one_socket(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
s.sockets['foo'] = mock_socket
|
|
s.disconnect('foo')
|
|
self.assertEqual(mock_socket.close.call_count, 1)
|
|
self.assertNotIn('foo', s.sockets)
|
|
|
|
def test_close_all_sockets(self):
|
|
s = server.Server()
|
|
mock_sockets = {}
|
|
for sid in ['foo', 'bar', 'baz']:
|
|
mock_sockets[sid] = self._get_mock_socket()
|
|
s.sockets[sid] = mock_sockets[sid]
|
|
s.disconnect()
|
|
for socket in six.itervalues(mock_sockets):
|
|
self.assertEqual(socket.close.call_count, 1)
|
|
self.assertEqual(s.sockets, {})
|
|
|
|
def test_upgrades(self):
|
|
s = server.Server()
|
|
s.sockets['foo'] = self._get_mock_socket()
|
|
self.assertEqual(s._upgrades('foo', 'polling'), ['websocket'])
|
|
self.assertEqual(s._upgrades('foo', 'websocket'), [])
|
|
s.sockets['foo'].upgraded = True
|
|
self.assertEqual(s._upgrades('foo', 'polling'), [])
|
|
self.assertEqual(s._upgrades('foo', 'websocket'), [])
|
|
s.allow_upgrades = False
|
|
s.sockets['foo'].upgraded = True
|
|
self.assertEqual(s._upgrades('foo', 'polling'), [])
|
|
self.assertEqual(s._upgrades('foo', 'websocket'), [])
|
|
|
|
def test_transport(self):
|
|
s = server.Server()
|
|
s.sockets['foo'] = self._get_mock_socket()
|
|
s.sockets['foo'].upgraded = False
|
|
s.sockets['bar'] = self._get_mock_socket()
|
|
s.sockets['bar'].upgraded = True
|
|
self.assertEqual(s.transport('foo'), 'polling')
|
|
self.assertEqual(s.transport('bar'), 'websocket')
|
|
|
|
def test_bad_session(self):
|
|
s = server.Server()
|
|
s.sockets['foo'] = 'client'
|
|
self.assertRaises(KeyError, s._get_socket, 'bar')
|
|
|
|
def test_closed_socket(self):
|
|
s = server.Server()
|
|
s.sockets['foo'] = self._get_mock_socket()
|
|
s.sockets['foo'].closed = True
|
|
self.assertRaises(KeyError, s._get_socket, 'foo')
|
|
|
|
def test_jsonp_not_supported(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'j=abc'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
|
|
def test_connect(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertEqual(len(s.sockets), 1)
|
|
self.assertEqual(start_response.call_count, 1)
|
|
self.assertEqual(start_response.call_args[0][0], '200 OK')
|
|
self.assertIn(('Content-Type', 'application/octet-stream'),
|
|
start_response.call_args[0][1])
|
|
self.assertEqual(len(r), 1)
|
|
packets = payload.Payload(encoded_payload=r[0]).packets
|
|
self.assertEqual(len(packets), 1)
|
|
self.assertEqual(packets[0].packet_type, packet.OPEN)
|
|
self.assertIn('upgrades', packets[0].data)
|
|
self.assertEqual(packets[0].data['upgrades'], ['websocket'])
|
|
self.assertIn('sid', packets[0].data)
|
|
|
|
def test_connect_no_upgrades(self):
|
|
s = server.Server(allow_upgrades=False)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
packets = payload.Payload(encoded_payload=r[0]).packets
|
|
self.assertEqual(packets[0].data['upgrades'], [])
|
|
|
|
def test_connect_b64_with_1(self):
|
|
s = server.Server(allow_upgrades=False)
|
|
s._generate_id = mock.MagicMock(return_value='1')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=1'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertTrue(start_response.call_args[0][0], '200 OK')
|
|
self.assertIn(('Content-Type', 'text/plain; charset=UTF-8'),
|
|
start_response.call_args[0][1])
|
|
s.send('1', b'\x00\x01\x02', binary=True)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=1'}
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertEqual(r[0], b'6:b4AAEC')
|
|
|
|
def test_connect_b64_with_true(self):
|
|
s = server.Server(allow_upgrades=False)
|
|
s._generate_id = mock.MagicMock(return_value='1')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=true'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertTrue(start_response.call_args[0][0], '200 OK')
|
|
self.assertIn(('Content-Type', 'text/plain; charset=UTF-8'),
|
|
start_response.call_args[0][1])
|
|
s.send('1', b'\x00\x01\x02', binary=True)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=true'}
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertEqual(r[0], b'6:b4AAEC')
|
|
|
|
def test_connect_b64_with_0(self):
|
|
s = server.Server(allow_upgrades=False)
|
|
s._generate_id = mock.MagicMock(return_value='1')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=0'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertTrue(start_response.call_args[0][0], '200 OK')
|
|
self.assertIn(('Content-Type', 'application/octet-stream'),
|
|
start_response.call_args[0][1])
|
|
s.send('1', b'\x00\x01\x02', binary=True)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=0'}
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertEqual(r[0], b'\x01\x04\xff\x04\x00\x01\x02')
|
|
|
|
def test_connect_b64_with_false(self):
|
|
s = server.Server(allow_upgrades=False)
|
|
s._generate_id = mock.MagicMock(return_value='1')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'b64=false'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertTrue(start_response.call_args[0][0], '200 OK')
|
|
self.assertIn(('Content-Type', 'application/octet-stream'),
|
|
start_response.call_args[0][1])
|
|
s.send('1', b'\x00\x01\x02', binary=True)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=1&b64=false'}
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertEqual(r[0], b'\x01\x04\xff\x04\x00\x01\x02')
|
|
|
|
def test_connect_custom_ping_times(self):
|
|
s = server.Server(ping_timeout=123, ping_interval=456)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
packets = payload.Payload(encoded_payload=r[0]).packets
|
|
self.assertEqual(packets[0].data['pingTimeout'], 123000)
|
|
self.assertEqual(packets[0].data['pingInterval'], 456000)
|
|
|
|
@mock.patch('engineio.socket.Socket.poll',
|
|
side_effect=exceptions.QueueEmpty)
|
|
def test_connect_bad_poll(self, poll):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
|
|
@mock.patch('engineio.socket.Socket',
|
|
return_value=mock.MagicMock(connected=False, closed=False))
|
|
def test_connect_transport_websocket(self, Socket):
|
|
s = server.Server()
|
|
s._generate_id = mock.MagicMock(return_value='123')
|
|
environ = {'REQUEST_METHOD': 'GET',
|
|
'QUERY_STRING': 'transport=websocket'}
|
|
start_response = mock.MagicMock()
|
|
# force socket to stay open, so that we can check it later
|
|
Socket().closed = False
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(s.sockets['123'].send.call_args[0][0].packet_type,
|
|
packet.OPEN)
|
|
|
|
@mock.patch('engineio.socket.Socket',
|
|
return_value=mock.MagicMock(connected=False, closed=False))
|
|
def test_connect_transport_websocket_closed(self, Socket):
|
|
s = server.Server()
|
|
s._generate_id = mock.MagicMock(return_value='123')
|
|
environ = {'REQUEST_METHOD': 'GET',
|
|
'QUERY_STRING': 'transport=websocket'}
|
|
start_response = mock.MagicMock()
|
|
|
|
def mock_handle(environ, start_response):
|
|
s.sockets['123'].closed = True
|
|
|
|
Socket().handle_get_request = mock_handle
|
|
s.handle_request(environ, start_response)
|
|
self.assertNotIn('123', s.sockets)
|
|
|
|
def test_connect_transport_invalid(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'transport=foo'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
|
|
def test_connect_cors_headers(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
|
|
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)
|
|
|
|
def test_connect_cors_allowed_origin(self):
|
|
s = server.Server(cors_allowed_origins=['a', 'b'])
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
|
|
'HTTP_ORIGIN': 'b'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertIn(('Access-Control-Allow-Origin', 'b'), headers)
|
|
|
|
def test_connect_cors_not_allowed_origin(self):
|
|
s = server.Server(cors_allowed_origins=['a', 'b'])
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
|
|
'HTTP_ORIGIN': 'c'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertNotIn(('Access-Control-Allow-Origin', 'c'), headers)
|
|
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)
|
|
|
|
def test_connect_cors_headers_all_origins(self):
|
|
s = server.Server(cors_allowed_origins='*')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertIn(('Access-Control-Allow-Origin', '*'), headers)
|
|
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)
|
|
|
|
def test_connect_cors_headers_one_origin(self):
|
|
s = server.Server(cors_allowed_origins='a')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
|
|
'HTTP_ORIGIN': 'a'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertIn(('Access-Control-Allow-Origin', 'a'), headers)
|
|
self.assertIn(('Access-Control-Allow-Credentials', 'true'), headers)
|
|
|
|
def test_connect_cors_headers_one_origin_not_allowed(self):
|
|
s = server.Server(cors_allowed_origins='a')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': '',
|
|
'HTTP_ORIGIN': 'b'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertNotIn(('Access-Control-Allow-Origin', 'b'), headers)
|
|
self.assertNotIn(('Access-Control-Allow-Origin', '*'), headers)
|
|
|
|
def test_connect_cors_no_credentials(self):
|
|
s = server.Server(cors_credentials=False)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertNotIn(('Access-Control-Allow-Credentials', 'true'), headers)
|
|
|
|
def test_cors_options(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'OPTIONS', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertIn(('Access-Control-Allow-Methods', 'OPTIONS, GET, POST'),
|
|
headers)
|
|
|
|
def test_cors_request_headers(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET',
|
|
'HTTP_ACCESS_CONTROL_REQUEST_HEADERS': 'Foo, Bar'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
headers = start_response.call_args[0][1]
|
|
self.assertIn(('Access-Control-Allow-Headers', 'Foo, Bar'), headers)
|
|
|
|
def test_connect_event(self):
|
|
s = server.Server()
|
|
s._generate_id = mock.MagicMock(return_value='123')
|
|
mock_event = mock.MagicMock()
|
|
s.on('connect')(mock_event)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
mock_event.assert_called_once_with('123', environ)
|
|
self.assertEqual(len(s.sockets), 1)
|
|
|
|
def test_connect_event_rejects(self):
|
|
s = server.Server()
|
|
s._generate_id = mock.MagicMock(return_value='123')
|
|
mock_event = mock.MagicMock(return_value=False)
|
|
s.on('connect')(mock_event)
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(len(s.sockets), 0)
|
|
self.assertEqual(start_response.call_args[0][0], '401 UNAUTHORIZED')
|
|
|
|
def test_method_not_found(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'PUT', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'405 METHOD NOT FOUND')
|
|
|
|
def test_get_request_with_bad_sid(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
|
|
def test_post_request_with_bad_sid(self):
|
|
s = server.Server()
|
|
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
|
|
def test_send(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
s.sockets['foo'] = mock_socket
|
|
s.send('foo', 'hello')
|
|
self.assertEqual(mock_socket.send.call_count, 1)
|
|
self.assertEqual(mock_socket.send.call_args[0][0].packet_type,
|
|
packet.MESSAGE)
|
|
self.assertEqual(mock_socket.send.call_args[0][0].data, 'hello')
|
|
|
|
def test_send_unknown_socket(self):
|
|
s = server.Server()
|
|
# just ensure no exceptions are raised
|
|
s.send('foo', 'hello')
|
|
|
|
def test_get_request(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'200 OK')
|
|
self.assertEqual(len(r), 1)
|
|
packets = payload.Payload(encoded_payload=r[0]).packets
|
|
self.assertEqual(len(packets), 1)
|
|
self.assertEqual(packets[0].packet_type, packet.MESSAGE)
|
|
|
|
def test_get_request_custom_response(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(side_effect=['resp'])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
self.assertEqual(s.handle_request(environ, start_response), 'resp')
|
|
|
|
def test_get_request_closes_socket(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
|
|
def mock_get_request(*args, **kwargs):
|
|
mock_socket.closed = True
|
|
return 'resp'
|
|
|
|
mock_socket.handle_get_request = mock_get_request
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
self.assertEqual(s.handle_request(environ, start_response), 'resp')
|
|
self.assertNotIn('foo', s.sockets)
|
|
|
|
def test_get_request_error(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(
|
|
side_effect=[exceptions.QueueEmpty])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
self.assertEqual(len(s.sockets), 0)
|
|
|
|
def test_post_request(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_post_request = mock.MagicMock()
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'200 OK')
|
|
|
|
def test_post_request_error(self):
|
|
s = server.Server()
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_post_request = mock.MagicMock(
|
|
side_effect=[exceptions.EngineIOError])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo'}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertEqual(start_response.call_args[0][0],
|
|
'400 BAD REQUEST')
|
|
self.assertNotIn('foo', s.sockets)
|
|
|
|
@staticmethod
|
|
def _gzip_decompress(b):
|
|
bytesio = six.BytesIO(b)
|
|
with gzip.GzipFile(fileobj=bytesio, mode='r') as gz:
|
|
return gz.read()
|
|
|
|
def test_gzip_compression(self):
|
|
s = server.Server(compression_threshold=0)
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
|
|
'HTTP_ACCEPT_ENCODING': 'gzip,deflate'}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertIn(('Content-Encoding', 'gzip'),
|
|
start_response.call_args[0][1])
|
|
self._gzip_decompress(r[0])
|
|
|
|
def test_deflate_compression(self):
|
|
s = server.Server(compression_threshold=0)
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
|
|
'HTTP_ACCEPT_ENCODING': 'deflate;q=1,gzip'}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
self.assertIn(('Content-Encoding', 'deflate'),
|
|
start_response.call_args[0][1])
|
|
zlib.decompress(r[0])
|
|
|
|
def test_gzip_compression_threshold(self):
|
|
s = server.Server(compression_threshold=1000)
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
|
|
'HTTP_ACCEPT_ENCODING': 'gzip'}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
for header, value in start_response.call_args[0][1]:
|
|
self.assertNotEqual(header, 'Content-Encoding')
|
|
self.assertRaises(IOError, self._gzip_decompress, r[0])
|
|
|
|
def test_compression_disabled(self):
|
|
s = server.Server(http_compression=False, compression_threshold=0)
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
|
|
'HTTP_ACCEPT_ENCODING': 'gzip'}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
for header, value in start_response.call_args[0][1]:
|
|
self.assertNotEqual(header, 'Content-Encoding')
|
|
self.assertRaises(IOError, self._gzip_decompress, r[0])
|
|
|
|
def test_compression_unknown(self):
|
|
s = server.Server(compression_threshold=0)
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
|
|
'HTTP_ACCEPT_ENCODING': 'rar'}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
for header, value in start_response.call_args[0][1]:
|
|
self.assertNotEqual(header, 'Content-Encoding')
|
|
self.assertRaises(IOError, self._gzip_decompress, r[0])
|
|
|
|
def test_compression_no_encoding(self):
|
|
s = server.Server(compression_threshold=0)
|
|
mock_socket = self._get_mock_socket()
|
|
mock_socket.handle_get_request = mock.MagicMock(return_value=[
|
|
packet.Packet(packet.MESSAGE, data='hello')])
|
|
s.sockets['foo'] = mock_socket
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
|
|
'HTTP_ACCEPT_ENCODING': ''}
|
|
start_response = mock.MagicMock()
|
|
r = s.handle_request(environ, start_response)
|
|
for header, value in start_response.call_args[0][1]:
|
|
self.assertNotEqual(header, 'Content-Encoding')
|
|
self.assertRaises(IOError, self._gzip_decompress, r[0])
|
|
|
|
def test_cookie(self):
|
|
s = server.Server(cookie='sid')
|
|
s._generate_id = mock.MagicMock(return_value='123')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
self.assertIn(('Set-Cookie', 'sid=123'),
|
|
start_response.call_args[0][1])
|
|
|
|
def test_no_cookie(self):
|
|
s = server.Server(cookie=None)
|
|
s._generate_id = mock.MagicMock(return_value='123')
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
for header, value in start_response.call_args[0][1]:
|
|
self.assertNotEqual(header, 'Set-Cookie')
|
|
|
|
def test_logger(self):
|
|
s = server.Server(logger=False)
|
|
self.assertEqual(s.logger.getEffectiveLevel(), logging.ERROR)
|
|
s.logger.setLevel(logging.NOTSET)
|
|
s = server.Server(logger=True)
|
|
self.assertEqual(s.logger.getEffectiveLevel(), logging.INFO)
|
|
s.logger.setLevel(logging.WARNING)
|
|
s = server.Server(logger=True)
|
|
self.assertEqual(s.logger.getEffectiveLevel(), logging.WARNING)
|
|
s.logger.setLevel(logging.NOTSET)
|
|
my_logger = logging.Logger('foo')
|
|
s = server.Server(logger=my_logger)
|
|
self.assertEqual(s.logger, my_logger)
|
|
|
|
def test_custom_json(self):
|
|
# Warning: this test cannot run in parallel with other tests, as it
|
|
# changes the JSON encoding/decoding functions
|
|
|
|
class CustomJSON(object):
|
|
@staticmethod
|
|
def dumps(*args, **kwargs):
|
|
return '*** encoded ***'
|
|
|
|
@staticmethod
|
|
def loads(*args, **kwargs):
|
|
return '+++ decoded +++'
|
|
|
|
server.Server(json=CustomJSON)
|
|
pkt = packet.Packet(packet.MESSAGE, data={'foo': 'bar'})
|
|
self.assertEqual(pkt.encode(), b'4*** encoded ***')
|
|
pkt2 = packet.Packet(encoded_packet=pkt.encode())
|
|
self.assertEqual(pkt2.data, '+++ decoded +++')
|
|
|
|
# restore the default JSON module
|
|
packet.Packet.json = json
|
|
|
|
def test_background_tasks(self):
|
|
flag = {}
|
|
|
|
def bg_task():
|
|
flag['task'] = True
|
|
|
|
s = server.Server()
|
|
task = s.start_background_task(bg_task)
|
|
task.join()
|
|
self.assertIn('task', flag)
|
|
self.assertTrue(flag['task'])
|
|
|
|
def test_sleep(self):
|
|
s = server.Server()
|
|
t = time.time()
|
|
s.sleep(0.1)
|
|
self.assertTrue(time.time() - t > 0.1)
|
|
|
|
def test_service_task_started(self):
|
|
s = server.Server(monitor_clients=True)
|
|
s._service_task = mock.MagicMock()
|
|
environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': ''}
|
|
start_response = mock.MagicMock()
|
|
s.handle_request(environ, start_response)
|
|
s._service_task.assert_called_once_with()
|