425 lines
17 KiB
Python
425 lines
17 KiB
Python
import time
|
|
import unittest
|
|
|
|
import six
|
|
if six.PY3:
|
|
from unittest import mock
|
|
else:
|
|
import mock
|
|
|
|
from engineio import exceptions
|
|
from engineio import packet
|
|
from engineio import payload
|
|
from engineio import socket
|
|
|
|
|
|
class TestSocket(unittest.TestCase):
    """Unit tests for ``engineio.socket.Socket``.

    Every test builds its socket on top of a ``mock.Mock`` server (see
    ``_get_mock_server``) whose background tasks run on real threads and
    whose queues are real ``queue.Queue`` objects.
    """

    def setUp(self):
        # Threads started through the mock server's start_background_task
        # are recorded here so tests can join them before asserting.
        self.bg_tasks = []

    def _get_mock_server(self):
        """Return a mock server suitable for constructing a ``Socket``.

        The mock has 0.2 s ping timeout/interval and ``async_handlers``
        enabled. ``start_background_task`` starts a real
        ``threading.Thread`` and records it in ``self.bg_tasks``.
        ``create_queue`` returns a real ``queue.Queue``, and
        ``get_queue_empty_exception()`` returns ``queue.Empty``. The
        ``'websocket'`` entry of ``_async`` is ``None`` unless a test
        overrides it.
        """
        mock_server = mock.Mock()
        mock_server.ping_timeout = 0.2
        mock_server.ping_interval = 0.2
        mock_server.async_handlers = True

        try:
            import queue
        except ImportError:
            # Python 2 spells the module "Queue".
            import Queue as queue
        import threading

        mock_server._async = {'threading': threading.Thread,
                              'queue': queue.Queue,
                              'websocket': None}

        def bg_task(target, *args, **kwargs):
            th = threading.Thread(target=target, args=args, kwargs=kwargs)
            self.bg_tasks.append(th)
            th.start()
            return th

        def create_queue(*args, **kwargs):
            return queue.Queue(*args, **kwargs)

        mock_server.start_background_task = bg_task
        mock_server.create_queue = create_queue
        mock_server.get_queue_empty_exception.return_value = queue.Empty
        return mock_server

    def _join_bg_tasks(self):
        """Wait for every background thread started so far to finish."""
        for task in self.bg_tasks:
            task.join()

    def test_create(self):
        """A new socket keeps its server and sid and starts idle."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        self.assertEqual(s.server, mock_server)
        self.assertEqual(s.sid, 'sid')
        self.assertFalse(s.upgraded)
        self.assertFalse(s.closed)
        # The outgoing queue only needs to look like a Queue.
        self.assertTrue(hasattr(s.queue, 'get'))
        self.assertTrue(hasattr(s.queue, 'put'))
        self.assertTrue(hasattr(s.queue, 'task_done'))
        self.assertTrue(hasattr(s.queue, 'join'))

    def test_empty_poll(self):
        """poll() on a socket with nothing queued raises QueueEmpty."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        self.assertRaises(exceptions.QueueEmpty, s.poll)

    def test_poll(self):
        """poll() returns the sent packets in send order."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        s.send(pkt1)
        s.send(pkt2)
        self.assertEqual(s.poll(), [pkt1, pkt2])

    def test_ping_pong(self):
        """Receiving a PING queues one reply that encodes as b'3abc'."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.receive(packet.Packet(packet.PING, data='abc'))
        r = s.poll()
        self.assertEqual(len(r), 1)
        # assertEqual, not assertTrue: assertTrue treats its second
        # argument as the failure message, so the encoding was never
        # actually compared.
        self.assertEqual(r[0].encode(), b'3abc')

    def test_message_async_handler(self):
        """With async_handlers on, 'message' is triggered run_async=True."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.receive(packet.Packet(packet.MESSAGE, data='foo'))
        mock_server._trigger_event.assert_called_once_with('message', 'sid',
                                                           'foo',
                                                           run_async=True)

    def test_message_sync_handler(self):
        """With async_handlers off, 'message' is triggered run_async=False."""
        mock_server = self._get_mock_server()
        mock_server.async_handlers = False
        s = socket.Socket(mock_server, 'sid')
        s.receive(packet.Packet(packet.MESSAGE, data='foo'))
        mock_server._trigger_event.assert_called_once_with('message', 'sid',
                                                           'foo',
                                                           run_async=False)

    def test_invalid_packet(self):
        """Receiving an OPEN packet raises UnknownPacketError."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        self.assertRaises(exceptions.UnknownPacketError, s.receive,
                          packet.Packet(packet.OPEN))

    def test_timeout(self):
        """send() closes the socket when the ping deadline has lapsed."""
        mock_server = self._get_mock_server()
        # A negative interval plus a ping one second in the past is meant
        # to put the socket past its ping deadline before send() runs.
        mock_server.ping_interval = -6
        s = socket.Socket(mock_server, 'sid')
        s.last_ping = time.time() - 1
        s.close = mock.MagicMock()
        s.send('packet')
        s.close.assert_called_once_with(wait=False, abort=False)

    def test_polling_read(self):
        """A GET polling request returns the queued packets."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'foo')
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        s.send(pkt1)
        s.send(pkt2)
        environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
        start_response = mock.MagicMock()
        packets = s.handle_get_request(environ, start_response)
        self.assertEqual(packets, [pkt1, pkt2])

    def test_polling_read_error(self):
        """A GET polling request with nothing queued raises QueueEmpty."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'foo')
        environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo'}
        start_response = mock.MagicMock()
        self.assertRaises(exceptions.QueueEmpty, s.handle_get_request, environ,
                          start_response)

    def test_polling_write(self):
        """A POST with a two-packet payload calls receive() twice."""
        mock_server = self._get_mock_server()
        mock_server.max_http_buffer_size = 1000
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        p = payload.Payload(packets=[pkt1, pkt2]).encode()
        s = socket.Socket(mock_server, 'foo')
        s.receive = mock.MagicMock()
        environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
                   'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
        s.handle_post_request(environ)
        self.assertEqual(s.receive.call_count, 2)

    def test_polling_write_too_large(self):
        """A POST one byte over max_http_buffer_size is rejected."""
        mock_server = self._get_mock_server()
        pkt1 = packet.Packet(packet.MESSAGE, data='hello')
        pkt2 = packet.Packet(packet.MESSAGE, data='bye')
        p = payload.Payload(packets=[pkt1, pkt2]).encode()
        mock_server.max_http_buffer_size = len(p) - 1
        s = socket.Socket(mock_server, 'foo')
        s.receive = mock.MagicMock()
        environ = {'REQUEST_METHOD': 'POST', 'QUERY_STRING': 'sid=foo',
                   'CONTENT_LENGTH': len(p), 'wsgi.input': six.BytesIO(p)}
        self.assertRaises(exceptions.ContentTooLongError,
                          s.handle_post_request, environ)

    def test_upgrade_handshake(self):
        """A GET carrying websocket upgrade headers goes to the upgrader."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'foo')
        s._upgrade_websocket = mock.MagicMock()
        # "Upgrade" sits in the middle of a comma-separated Connection list.
        environ = {'REQUEST_METHOD': 'GET', 'QUERY_STRING': 'sid=foo',
                   'HTTP_CONNECTION': 'Foo,Upgrade,Bar',
                   'HTTP_UPGRADE': 'websocket'}
        start_response = mock.MagicMock()
        s.handle_get_request(environ, start_response)
        s._upgrade_websocket.assert_called_once_with(environ, start_response)

    def test_upgrade(self):
        """_upgrade_websocket wraps the handler and calls it with the args."""
        mock_server = self._get_mock_server()
        mock_server._async['websocket'] = mock.MagicMock()
        mock_ws = mock.MagicMock()
        mock_server._async['websocket'].return_value = mock_ws
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        environ = "foo"
        start_response = "bar"
        s._upgrade_websocket(environ, start_response)
        mock_server._async['websocket'].assert_called_once_with(
            s._websocket_handler)
        mock_ws.assert_called_once_with(environ, start_response)

    def test_upgrade_twice(self):
        """Upgrading an already-upgraded socket raises IOError."""
        mock_server = self._get_mock_server()
        mock_server._async['websocket'] = mock.MagicMock()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.upgraded = True
        environ = "foo"
        start_response = "bar"
        self.assertRaises(IOError, s._upgrade_websocket,
                          environ, start_response)

    def test_upgrade_packet(self):
        """An UPGRADE packet received over polling queues a single NOOP."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.receive(packet.Packet(packet.UPGRADE))
        r = s.poll()
        self.assertEqual(len(r), 1)
        self.assertEqual(r[0].encode(), packet.Packet(packet.NOOP).encode())

    def test_upgrade_no_probe(self):
        """A first websocket message that is not a probe aborts upgrade."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        ws = mock.MagicMock()
        ws.wait.return_value = packet.Packet(packet.NOOP).encode(
            always_bytes=False)
        s._websocket_handler(ws)
        self.assertFalse(s.upgraded)

    def test_upgrade_no_upgrade_packet(self):
        """A probe is answered, but NOOP instead of UPGRADE aborts upgrade."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        # Stub join() so the handler cannot block on unfinished queue items
        # should it wait on the queue.
        s.queue.join = mock.MagicMock(return_value=None)
        ws = mock.MagicMock()
        probe = six.text_type('probe')
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.NOOP).encode(always_bytes=False)]
        s._websocket_handler(ws)
        ws.send.assert_called_once_with(packet.Packet(
            packet.PONG, data=probe).encode(always_bytes=False))
        self.assertEqual(s.queue.get().packet_type, packet.NOOP)
        self.assertFalse(s.upgraded)

    def test_close_packet(self):
        """Receiving CLOSE closes the socket with wait=False, abort=True."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.close = mock.MagicMock()
        s.receive(packet.Packet(packet.CLOSE))
        s.close.assert_called_once_with(wait=False, abort=True)

    def test_invalid_packet_type(self):
        """A packet with an unknown numeric type raises UnknownPacketError."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        pkt = packet.Packet(packet_type=99)
        self.assertRaises(exceptions.UnknownPacketError, s.receive, pkt)

    def test_upgrade_not_supported(self):
        """Without a websocket implementation the upgrade is a bad request."""
        mock_server = self._get_mock_server()
        mock_server._async['websocket'] = None
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        environ = "foo"
        start_response = "bar"
        s._upgrade_websocket(environ, start_response)
        mock_server._bad_request.assert_called_once_with()

    def test_websocket_read_write(self):
        """A direct websocket connection exchanges messages, then closes."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = False
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        # One outgoing batch, then QueueEmpty to end the writer side.
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        # A None from wait() ends the read loop.
        ws.wait.side_effect = [
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            None]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.connected)
        self.assertTrue(s.upgraded)
        self.assertEqual(mock_server._trigger_event.call_count, 2)
        mock_server._trigger_event.assert_has_calls([
            mock.call('message', 'sid', 'foo', run_async=True),
            mock.call('disconnect', 'sid', run_async=False)])
        ws.send.assert_called_with('4bar')

    def test_websocket_upgrade_read_write(self):
        """Probe + UPGRADE handshake, then a normal message exchange."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        probe = six.text_type('probe')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.UPGRADE).encode(always_bytes=False),
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            None]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.upgraded)
        self.assertEqual(mock_server._trigger_event.call_count, 2)
        mock_server._trigger_event.assert_has_calls([
            mock.call('message', 'sid', 'foo', run_async=True),
            mock.call('disconnect', 'sid', run_async=False)])
        ws.send.assert_called_with('4bar')

    def test_websocket_upgrade_with_payload(self):
        """An UPGRADE packet that carries data still completes the upgrade."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        probe = six.text_type('probe')
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.UPGRADE, data=b'2').encode(
                always_bytes=False)]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.upgraded)

    def test_websocket_upgrade_with_backlog(self):
        """A packet sent mid-upgrade goes out over the websocket afterwards."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = True
        s.queue.join = mock.MagicMock(return_value=None)
        probe = six.text_type('probe')
        foo = six.text_type('foo')
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.PING, data=probe).encode(
                always_bytes=False),
            packet.Packet(packet.UPGRADE, data=b'2').encode(
                always_bytes=False)]
        # Send while the socket is flagged as mid-upgrade.
        s.upgrading = True
        s.send(packet.Packet(packet.MESSAGE, data=foo))
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.upgraded)
        self.assertFalse(s.upgrading)
        self.assertEqual(s.packet_backlog, [])
        ws.send.assert_called_with('4foo')

    def test_websocket_read_write_wait_fail(self):
        """Errors from ws.wait()/ws.send() leave the socket closed."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = False
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)],
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            RuntimeError]
        # The first send succeeds; the second raises.
        ws.send.side_effect = [None, RuntimeError]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.closed)

    def test_websocket_ignore_invalid_packet(self):
        """An OPEN over the websocket is ignored; later messages still flow."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.connected = False
        s.queue.join = mock.MagicMock(return_value=None)
        foo = six.text_type('foo')
        bar = six.text_type('bar')
        s.poll = mock.MagicMock(side_effect=[
            [packet.Packet(packet.MESSAGE, data=bar)], exceptions.QueueEmpty])
        ws = mock.MagicMock()
        ws.wait.side_effect = [
            packet.Packet(packet.OPEN).encode(always_bytes=False),
            packet.Packet(packet.MESSAGE, data=foo).encode(
                always_bytes=False),
            None]
        s._websocket_handler(ws)
        self._join_bg_tasks()
        self.assertTrue(s.connected)
        self.assertEqual(mock_server._trigger_event.call_count, 2)
        mock_server._trigger_event.assert_has_calls([
            mock.call('message', 'sid', foo, run_async=True),
            mock.call('disconnect', 'sid', run_async=False)])
        ws.send.assert_called_with('4bar')

    def test_send_after_close(self):
        """send() on a closed socket raises SocketIsClosedError."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.close(wait=False)
        self.assertRaises(exceptions.SocketIsClosedError, s.send,
                          packet.Packet(packet.NOOP))

    def test_close_after_close(self):
        """'disconnect' fires once; a second close() triggers nothing."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.close(wait=False)
        self.assertTrue(s.closed)
        self.assertEqual(mock_server._trigger_event.call_count, 1)
        mock_server._trigger_event.assert_called_once_with('disconnect', 'sid',
                                                           run_async=False)
        s.close()
        self.assertEqual(mock_server._trigger_event.call_count, 1)

    def test_close_and_wait(self):
        """close(wait=True) joins the socket's queue exactly once."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.queue = mock.MagicMock()
        s.close(wait=True)
        s.queue.join.assert_called_once_with()

    def test_close_without_wait(self):
        """close(wait=False) never joins the socket's queue."""
        mock_server = self._get_mock_server()
        s = socket.Socket(mock_server, 'sid')
        s.queue = mock.MagicMock()
        s.close(wait=False)
        self.assertEqual(s.queue.join.call_count, 0)
|