Ejemplo n.º 1
0
class WebSocketXForwardedFor(unittest.TestCase):
    """
    Test that (only) a trusted X-Forwarded-For can replace the peer address.

    The factory is configured with ``trustXForwardedFor=2``; the test then
    feeds a handshake carrying three forwarded addresses and checks which
    one ends up as ``proto.peer``.
    """

    def setUp(self):
        """Build a server protocol wired to a factory and a string transport."""
        self.factory = WebSocketServerFactory()
        self.factory.setProtocolOptions(
            trustXForwardedFor=2
        )
        self.proto = WebSocketServerProtocol()
        self.proto.transport = StringTransport()
        self.proto.factory = self.factory
        # replace failHandshake with a Mock so calls to it are only recorded
        self.proto.failHandshake = Mock()
        self.proto._connectionMade()

    def tearDown(self):
        """Cancel every timer-like call on the protocol that is still set."""
        # presumably these are delayed calls scheduled by the protocol;
        # each one that is not None is cancelled so it cannot outlive the test
        pending_calls = (
            self.proto.autoPingPendingCall,
            self.proto.autoPingTimeoutCall,
            self.proto.openHandshakeTimeoutCall,
            self.proto.closeHandshakeTimeoutCall,
        )
        for call in pending_calls:
            if call is not None:
                call.cancel()

    def test_trusted_addresses(self):
        """
        With three addresses in X-Forwarded-For, "2.3.4.5" becomes the peer.
        """
        self.proto.data = b"\r\n".join([
            b'GET /ws HTTP/1.1',
            b'Host: www.example.com',
            b'Origin: http://www.example.com',
            b'Sec-WebSocket-Version: 13',
            b'Sec-WebSocket-Extensions: permessage-deflate',
            b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
            b'Connection: keep-alive, Upgrade',
            b'Upgrade: websocket',
            b'X-Forwarded-For: 1.2.3.4, 2.3.4.5, 111.222.33.44',
            b'\r\n',  # last string doesn't get a \r\n from join()
        ])
        self.proto.consumeData()

        # assertEqual, not the deprecated assertEquals alias (removed from
        # unittest.TestCase in Python 3.12)
        self.assertEqual(
            self.proto.peer, "2.3.4.5",
            "The second address in X-Forwarded-For should have been picked as the peer address")
Ejemplo n.º 2
0
class WebSocketOriginMatching(unittest.TestCase):
    """
    Test that we match Origin: headers properly, when asked to
    """

    def setUp(self):
        """Build a server protocol whose factory allows two origin wildcards."""
        self.factory = WebSocketServerFactory()
        self.factory.setProtocolOptions(
            allowedOrigins=[u'127.0.0.1:*', u'*.example.com:*']
        )
        self.proto = WebSocketServerProtocol()
        self.proto.transport = StringTransport()
        self.proto.factory = self.factory
        # replace failHandshake with a Mock so the tests can inspect its calls
        self.proto.failHandshake = Mock()
        self.proto._connectionMade()

    def tearDown(self):
        """Cancel every timer-like call on the protocol that is still set."""
        # presumably these are delayed calls scheduled by the protocol;
        # each one that is not None is cancelled so it cannot outlive the test
        pending_calls = (
            self.proto.autoPingPendingCall,
            self.proto.autoPingTimeoutCall,
            self.proto.openHandshakeTimeoutCall,
            self.proto.closeHandshakeTimeoutCall,
        )
        for call in pending_calls:
            if call is not None:
                call.cancel()

    def _set_allowed_origins(self, origins):
        """
        Replace the allowed origins on the factory AND on the protocol.

        setUp has already called ``_connectionMade()``, so changing the
        factory alone is not enough: the two origin attributes are copied
        onto the protocol by hand as well.
        """
        self.factory.setProtocolOptions(allowedOrigins=origins)
        self.proto.allowedOriginsPatterns = self.factory.allowedOriginsPatterns
        self.proto.allowedOrigins = self.factory.allowedOrigins

    def _send_handshake(self, origin):
        """
        Feed an opening handshake with the given Origin value to the protocol.

        :param origin: value of the ``Origin:`` header, as bytes
        """
        self.proto.data = b"\r\n".join([
            b'GET /ws HTTP/1.1',
            b'Host: www.example.com',
            b'Sec-WebSocket-Version: 13',
            b'Origin: ' + origin,
            b'Sec-WebSocket-Extensions: permessage-deflate',
            b'Sec-WebSocket-Key: tXAxWFUqnhi86Ajj7dRY5g==',
            b'Connection: keep-alive, Upgrade',
            b'Upgrade: websocket',
            b'\r\n',  # last string doesn't get a \r\n from join()
        ])
        self.proto.consumeData()

    def _assert_origin_rejected(self):
        """Assert failHandshake was called and its first argument says 'not allowed'."""
        self.assertTrue(self.proto.failHandshake.called, "Handshake should have failed")
        # first positional argument of the first recorded call
        arg = self.proto.failHandshake.mock_calls[0][1][0]
        self.assertIn('not allowed', arg)

    def test_match_full_origin(self):
        """
        An origin that merely starts with an allowed host must be rejected.
        """
        self._send_handshake(b'http://www.example.com.malicious.com')
        self._assert_origin_rejected()

    def test_match_wrong_scheme_origin(self):
        """
        An https origin is rejected when only http origins are allowed.
        """
        # some monkey-business since we already did this in setUp, but
        # we want a different set of matching origins
        self._set_allowed_origins([u'http://*.example.com:*'])

        # the actual test
        self.factory.isSecure = False
        self._send_handshake(b'https://www.example.com')
        self._assert_origin_rejected()

    def test_match_origin_secure_scheme(self):
        """
        An https origin is accepted by the scheme-less patterns from setUp.
        """
        self.factory.isSecure = True
        self.factory.port = 443
        self._send_handshake(b'https://www.example.com')

        self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")

    def test_match_origin_documentation_example(self):
        """
        Test the examples from the docs
        """
        # the new pattern has to reach the protocol too (same monkey-business
        # as in test_match_wrong_scheme_origin), otherwise this test would
        # still run against the patterns configured in setUp
        self._set_allowed_origins(['*://*.example.com:*'])
        self.factory.isSecure = True
        self.factory.port = 443
        self._send_handshake(b'http://www.example.com')

        self.assertFalse(self.proto.failHandshake.called, "Handshake should have succeeded")

    def test_match_origin_examples(self):
        """
        All the example origins from RFC6454 (3.2.1)
        """
        # we're just testing the low-level function here...
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        policy = wildcards2patterns(['*example.com:*'])

        # should parametrize test ...
        for url in ['http://example.com/', 'http://example.com:80/',
                    'http://example.com/path/file',
                    'http://example.com/;semi=true',
                    # 'http://example.com./',
                    '//example.com/',
                    'http://@example.com']:
            # the url is passed as the message so a failure names the culprit
            self.assertTrue(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)

    def test_match_origin_counter_examples(self):
        """
        All the example 'not-same' origins from RFC6454 (3.2.1)
        """
        # we're just testing the low-level function here...
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        policy = wildcards2patterns(['example.com'])

        for url in ['http://ietf.org/', 'http://example.org/',
                    'https://example.com/', 'http://example.com:8080/',
                    'http://www.example.com/']:
            # the url is passed as the message so a failure names the culprit
            self.assertFalse(_is_same_origin(_url_to_origin(url), 'http', 80, policy), url)

    def test_match_origin_edge(self):
        """
        A pattern with explicit scheme and port matches only that scheme and port.
        """
        # we're just testing the low-level function here...
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        policy = wildcards2patterns(['http://*example.com:80'])

        self.assertTrue(
            _is_same_origin(_url_to_origin('http://example.com:80'), 'http', 80, policy)
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('http://example.com:81'), 'http', 81, policy)
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('https://example.com:80'), 'http', 80, policy)
        )

    def test_origin_from_url(self):
        """
        _url_to_origin yields a (scheme, host, port) tuple with a lower-cased scheme.
        """
        from autobahn.websocket.protocol import _url_to_origin

        # basic function
        self.assertEqual(
            _url_to_origin('http://example.com'),
            ('http', 'example.com', 80)
        )
        # should lower-case scheme
        self.assertEqual(
            _url_to_origin('hTTp://example.com'),
            ('http', 'example.com', 80)
        )

    def test_origin_file(self):
        """
        A file:// URL maps to the 'null' origin.
        """
        from autobahn.websocket.protocol import _url_to_origin
        self.assertEqual('null', _url_to_origin('file:///etc/passwd'))

    def test_origin_null(self):
        """
        The 'null' origin stays 'null' and is never the same origin as anything.
        """
        from autobahn.websocket.protocol import _is_same_origin, _url_to_origin
        self.assertEqual('null', _url_to_origin('null'))
        # whatever host scheme is passed, 'null' must not match
        self.assertFalse(
            _is_same_origin(_url_to_origin('null'), 'http', 80, [])
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('null'), 'https', 80, [])
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('null'), '', 80, [])
        )
        self.assertFalse(
            _is_same_origin(_url_to_origin('null'), None, 80, [])
        )