class ExceptionHandlingTests(unittest.TestCase):
    """
    Tests that we format various exception variations properly during connectionLost
    """

    def setUp(self):
        self.factory = WebSocketServerFactory()
        self.proto = WebSocketServerProtocol()
        self.proto.factory = self.factory
        # capture everything the protocol logs so tests can inspect the text
        self.proto.log = Mock()

    def tearDown(self):
        # cancel any timers the protocol scheduled so none outlive the test
        for call in [
            self.proto.autoPingPendingCall,
            self.proto.autoPingTimeoutCall,
            self.proto.openHandshakeTimeoutCall,
            self.proto.closeHandshakeTimeoutCall,
        ]:
            if call is not None:
                call.cancel()

    def _lose_connection(self, reason, include_kwargs=False):
        """
        Pretend to connect, then lose the connection with ``Failure(reason)``.

        :param reason: exception instance wrapped in the Failure passed to
            ``connectionLost``.
        :param include_kwargs: when true, each log call contributes
            ``str(args) + str(kwargs)``; otherwise only ``str(args)``.
        :returns: the per-call texts of every call on the mocked logger,
            joined with single spaces.
        """
        # pretend we connected
        self.proto._connectionMade()
        self.proto.connectionLost(Failure(reason))
        calls = self.proto.log.mock_calls
        if include_kwargs:
            return ' '.join([str(x[1]) + str(x[2]) for x in calls])
        return ' '.join([str(x[1]) for x in calls])

    def test_connection_done(self):
        messages = self._lose_connection(ConnectionDone())
        self.assertIn('closed cleanly', messages)

    def test_connection_aborted(self):
        messages = self._lose_connection(ConnectionAborted())
        self.assertIn(' aborted ', messages)

    def test_connection_lost(self):
        messages = self._lose_connection(ConnectionLost())
        self.assertIn(' was lost ', messages)

    def test_connection_lost_arg(self):
        # the exception argument may arrive via keyword args, so include them
        messages = self._lose_connection(ConnectionLost("greetings"),
                                         include_kwargs=True)
        self.assertIn(' was lost ', messages)
        self.assertIn('greetings', messages)
class WebSocketXForwardedFor(unittest.TestCase):
    """
    Test that (only) a trusted X-Forwarded-For can replace the peer address.
    """

    def setUp(self):
        self.factory = WebSocketServerFactory()
        # trust the two right-most proxies listed in X-Forwarded-For
        self.factory.setProtocolOptions(
            trustXForwardedFor=2
        )
        self.proto = WebSocketServerProtocol()
        self.proto.transport = StringTransport()
        self.proto.factory = self.factory
        self.proto.failHandshake = Mock()
        self.proto._connectionMade()

    def tearDown(self):
        # cancel any timers the protocol scheduled so none outlive the test
        for call in [
            self.proto.autoPingPendingCall,
            self.proto.autoPingTimeoutCall,
            self.proto.openHandshakeTimeoutCall,
            self.proto.closeHandshakeTimeoutCall,
        ]:
            if call is not None:
                call.cancel()

    def test_trusted_addresses(self):
        self.proto.data = b"\r\n".join([
            b'GET /ws HTTP/1.1',
            b'Host: www.example.com',
            b'Origin: http://www.example.com',
            b'Sec-WebSocket-Version: 13',
            b'Sec-WebSocket-Extensions: permessage-deflate',
            b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
            b'Connection: keep-alive, Upgrade',
            b'Upgrade: websocket',
            b'X-Forwarded-For: 1.2.3.4, 2.3.4.5, 111.222.33.44',
            b'\r\n',  # last string doesn't get a \r\n from join()
        ])
        self.proto.consumeData()

        # assertEqual, not the deprecated assertEquals alias (removed in 3.12)
        self.assertEqual(
            self.proto.peer, "2.3.4.5",
            "The second address in X-Forwarded-For should have been picked as the peer address")
def test_handshake_fails(self):
    """
    An opening handshake from a client that only speaks Hixie-76 is
    rejected: the server writes an HTTP 400 naming the unsupported protocol.
    """
    transport = FakeTransport()
    factory = WebSocketServerFactory()
    proto = WebSocketServerProtocol()
    proto.factory = factory
    proto.transport = transport

    # Sample request taken from
    # http://tools.ietf.org/html/draft-hixie-thewebsocketprotocol-76
    hixie76_request = (
        b"GET /demo HTTP/1.1\r\n"
        b"Host: example.com\r\n"
        b"Connection: Upgrade\r\n"
        b"Sec-WebSocket-Key2: 12998 5 Y3 1 .P00\r\n"
        b"Sec-WebSocket-Protocol: sample\r\n"
        b"Upgrade: WebSocket\r\n"
        b"Sec-WebSocket-Key1: 4 @1 46546xW%0l 1 5\r\n"
        b"Origin: http://example.com\r\n"
        b"\r\n"
        b"^n:ds[4U"
    )

    # presumably stops _connectionMade() from scheduling an open-handshake
    # timeout; nothing in this test cancels one -- confirm
    proto.openHandshakeTimeout = 0
    proto._connectionMade()
    proto.data = hixie76_request
    proto.processHandshake()

    written = transport._written
    for expected in (b"HTTP/1.1 400", b"Hixie76 protocol not supported"):
        self.assertIn(expected, written)
class WebSocketOriginMatching(unittest.TestCase):
    """
    Test that we match Origin: headers properly, when asked to
    """

    def setUp(self):
        self.factory = WebSocketServerFactory()
        self.factory.setProtocolOptions(
            allowedOrigins=[u'127.0.0.1:*', u'*.example.com:*']
        )
        self.proto = WebSocketServerProtocol()
        self.proto.transport = StringTransport()
        self.proto.factory = self.factory
        self.proto.failHandshake = Mock()
        self.proto._connectionMade()

    def tearDown(self):
        # cancel any timers the protocol scheduled so none outlive the test
        for call in [
            self.proto.autoPingPendingCall,
            self.proto.autoPingTimeoutCall,
            self.proto.openHandshakeTimeoutCall,
            self.proto.closeHandshakeTimeoutCall,
        ]:
            if call is not None:
                call.cancel()

    def _set_allowed_origins(self, origins):
        """
        Replace the allowed-origins policy after setUp has already run.

        The protocol consults its own ``allowedOrigins`` /
        ``allowedOriginsPatterns`` attributes (presumably copied from the
        factory during ``_connectionMade()`` in setUp), so updating the
        factory alone would leave the old policy in force. Copy the new
        values onto the protocol as well.
        """
        self.factory.setProtocolOptions(allowedOrigins=origins)
        self.proto.allowedOriginsPatterns = self.factory.allowedOriginsPatterns
        self.proto.allowedOrigins = self.factory.allowedOrigins

    @staticmethod
    def _handshake_request(origin):
        """
        Return the raw bytes of an opening handshake carrying ``Origin: <origin>``.

        :param origin: the Origin header value, as bytes.
        """
        return b"\r\n".join([
            b'GET /ws HTTP/1.1',
            b'Host: www.example.com',
            b'Sec-WebSocket-Version: 13',
            b'Origin: ' + origin,
            b'Sec-WebSocket-Extensions: permessage-deflate',
            b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
            b'Connection: keep-alive, Upgrade',
            b'Upgrade: websocket',
            b'\r\n',  # last string doesn't get a \r\n from join()
        ])

    def _assert_handshake_rejected(self):
        """
        Assert the handshake was failed with a "not allowed" reason.
        """
        self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
        arg = self.proto.failHandshake.mock_calls[0][1][0]
        self.assertIn('not allowed', arg)

    def test_match_full_origin(self):
        # an allowed host used only as a prefix of a foreign host must not match
        self.proto.data = self._handshake_request(b'http://www.example.com.malicious.com')
        self.proto.consumeData()
        self._assert_handshake_rejected()

    def test_match_wrong_scheme_origin(self):
        # we want a different set of matching origins than setUp installed
        self._set_allowed_origins([u'http://*.example.com:*'])

        # the actual test
        self.factory.isSecure = False
        self.proto.data = self._handshake_request(b'https://www.example.com')
        self.proto.consumeData()
        self._assert_handshake_rejected()

    def test_match_origin_secure_scheme(self):
        self.factory.isSecure = True
        self.factory.port = 443
        self.proto.data = self._handshake_request(b'https://www.example.com')
        self.proto.consumeData()
        self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")

    def test_match_origin_documentation_example(self):
        """
        Test the examples from the docs
        """
        # Must also reach the protocol: changing only the factory here would
        # leave setUp's policy in force and never exercise the docs example.
        self._set_allowed_origins(['*://*.example.com:*'])
        self.factory.isSecure = True
        self.factory.port = 443
        self.proto.data = self._handshake_request(b'http://www.example.com')
        self.proto.consumeData()
        self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")

    def test_match_origin_examples(self):
        """
        All the example origins from RFC6454 (3.2.1)
        """
        # we're just testing the low-level function here...
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        policy = wildcards2patterns(['*example.com:*'])

        # should parametrize test ...
        for url in ['http://example.com/', 'http://example.com:80/',
                    'http://example.com/path/file',
                    'http://example.com/;semi=true',
                    # 'http://example.com./',
                    '//example.com/',
                    'http://@example.com']:
            self.assertTrue(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)

    def test_match_origin_counter_examples(self):
        """
        All the example 'not-same' origins from RFC6454 (3.2.1)
        """
        # we're just testing the low-level function here...
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        # NOTE(review): the other tests' patterns include scheme and port
        # ('http://*example.com:80'), so a bare 'example.com' may never match
        # anything and this test could pass vacuously -- consider a policy
        # such as 'http://example.com:80' plus a positive control.
        policy = wildcards2patterns(['example.com'])

        for url in ['http://ietf.org/', 'http://example.org/',
                    'https://example.com/', 'http://example.com:8080/',
                    'http://www.example.com/']:
            self.assertFalse(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)

    def test_match_origin_edge(self):
        # we're just testing the low-level function here...
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        policy = wildcards2patterns(['http://*example.com:80'])

        self.assertTrue(
            _is_same_origin(_url_to_origin('http://example.com:80'), 'http', 80, policy)
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('http://example.com:81'), 'http', 81, policy)
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('https://example.com:80'), 'http', 80, policy)
        )

    def test_origin_from_url(self):
        from autobahn.websocket.protocol import _url_to_origin

        # basic function
        self.assertEqual(
            _url_to_origin('http://example.com'),
            ('http', 'example.com', 80)
        )
        # should lower-case scheme
        self.assertEqual(
            _url_to_origin('hTTp://example.com'),
            ('http', 'example.com', 80)
        )

    def test_origin_file(self):
        from autobahn.websocket.protocol import _url_to_origin
        self.assertEqual('null', _url_to_origin('file:///etc/passwd'))

    def test_origin_null(self):
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        self.assertEqual('null', _url_to_origin('null'))
        # the 'null' origin must never be considered same-origin, whatever
        # the host scheme
        for scheme in ('http', 'https', '', None):
            self.assertFalse(
                _is_same_origin(_url_to_origin('null'), scheme, 80, []),
                repr(scheme)
            )