class TestHttpProtocolHandler(unittest.TestCase):
    """Tests for HttpProtocolHandler with the HTTP proxy and web server plugins.

    The client socket comes from a patched ``socket.fromfd`` and, where a test
    reaches an upstream, ``proxy.http.proxy.server.TcpServerConnection`` is
    patched too. ``selectors.DefaultSelector`` is patched so each test can
    script which readiness events ``run_once()`` observes.
    """

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """Build a handler around a mocked client socket.

        ``mock.patch`` decorators apply bottom-up, so ``mock_fromfd`` is the
        first mock argument. The patches are undone when ``setUp`` returns;
        tests keep using the captured mock objects (``self._conn``,
        ``self.mock_selector``).
        """
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = mock_fromfd.return_value

        # Only used to build request-line / Host text; the tests that use it
        # patch TcpServerConnection, so nothing listens on this port.
        self.http_server_port = 65535
        self.flags = Proxy.initialize()
        self.flags.plugins = Proxy.load_plugins([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])

        self.mock_selector = mock_selector
        self.protocol_handler = HttpProtocolHandler(TcpClientConnection(
            self._conn, self._addr),
                                                    flags=self.flags)
        self.protocol_handler.initialize()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_http_get(self, mock_server_connection: mock.Mock) -> None:
        """Proxy a plain-HTTP GET whose request line and headers arrive in separate reads."""
        server = mock_server_connection.return_value
        # Upstream connect succeeds and starts with nothing buffered.
        server.connect.return_value = True
        server.buffer_size.return_value = 0
        # Selector script: client readable, client readable, upstream writable.
        self.mock_selector_for_client_read_read_server_write(
            self.mock_selector, server)

        # Send request line
        assert self.http_server_port is not None
        self._conn.recv.return_value = (b'GET http://localhost:%d HTTP/1.1' %
                                        self.http_server_port) + CRLF
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.LINE_RCVD)
        self.assertNotEqual(self.protocol_handler.request.state,
                            httpParserStates.COMPLETE)

        # Send headers and blank line, thus completing HTTP request
        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port, b'Accept: */*',
            b'Proxy-Connection: Keep-Alive', CRLF
        ])
        self.assert_data_queued(mock_server_connection, server)
        # Third scripted event is "upstream writable": the queued request
        # must be flushed exactly once.
        self.protocol_handler.run_once()
        server.flush.assert_called_once()

    def assert_tunnel_response(self, mock_server_connection: mock.Mock,
                               server: mock.Mock) -> None:
        """Run one iteration on a pending CONNECT and check the tunnel setup.

        Asserts that:
        - an upstream connection was created and connected;
        - nothing has been queued upstream yet;
        - the first client buffer entry is
          ``HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT``, which
          parses as a complete ``200`` response.
        """
        self.protocol_handler.run_once()
        self.assertTrue(
            cast(HttpProxyPlugin, self.protocol_handler.
                 plugins['HttpProxyPlugin']).server is not None)
        self.assertEqual(self.protocol_handler.client.buffer[0],
                         HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.queue.assert_not_called()
        # Presumably lets the handler treat the mocked upstream as open on
        # later iterations — NOTE(review): confirm against the handler.
        server.closed = False

        parser = HttpParser(httpParserTypes.RESPONSE_PARSER)
        parser.parse(self.protocol_handler.client.buffer[0].tobytes())
        self.assertEqual(parser.state, httpParserStates.COMPLETE)
        assert parser.code is not None
        self.assertEqual(int(parser.code), 200)

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_http_tunnel(self, mock_server_connection: mock.Mock) -> None:
        """Open a CONNECT tunnel, send the 200 to the client, then relay a GET upstream."""
        server = mock_server_connection.return_value
        server.connect.return_value = True

        def has_buffer() -> bool:
            """Report pending upstream data once anything has been queued."""
            return cast(bool, server.queue.called)

        server.has_buffer.side_effect = has_buffer
        # One entry per run_once() below (assuming one select() per
        # iteration):
        #   1. client readable   -> CONNECT request
        #   2. client writable   -> tunnel-established response sent
        #   3. client readable   -> tunneled GET
        #   4. upstream writable -> GET flushed
        self.mock_selector.return_value.select.side_effect = [
            [
                (selectors.SelectorKey(fileobj=self._conn,
                                       fd=self._conn.fileno,
                                       events=selectors.EVENT_READ,
                                       data=None), selectors.EVENT_READ),
            ],
            [
                (selectors.SelectorKey(fileobj=self._conn,
                                       fd=self._conn.fileno,
                                       events=0,
                                       data=None), selectors.EVENT_WRITE),
            ],
            [
                (selectors.SelectorKey(fileobj=self._conn,
                                       fd=self._conn.fileno,
                                       events=selectors.EVENT_READ,
                                       data=None), selectors.EVENT_READ),
            ],
            [
                (selectors.SelectorKey(fileobj=server.connection,
                                       fd=server.connection.fileno,
                                       events=0,
                                       data=None), selectors.EVENT_WRITE),
            ],
        ]

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Proxy-Connection: Keep-Alive', CRLF
        ])
        self.assert_tunnel_response(mock_server_connection, server)

        # Dispatch tunnel established response to client
        self.protocol_handler.run_once()
        self.assert_data_queued_to_server(server)

        self.protocol_handler.run_once()
        self.assertEqual(server.queue.call_count, 1)
        server.flush.assert_called_once()

    def test_proxy_connection_failed(self) -> None:
        """A request to an unreachable host queues ProxyConnectionFailed.RESPONSE_PKT."""
        # TcpServerConnection is NOT patched here, so the real connection code
        # runs against "unknown.domain" — NOTE(review): this may perform a real
        # DNS lookup.
        self.mock_selector_for_client_read(self.mock_selector)
        self._conn.recv.return_value = CRLF.join([
            b'GET http://unknown.domain HTTP/1.1', b'Host: unknown.domain',
            CRLF
        ])
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.client.buffer[0],
                         ProxyConnectionFailed.RESPONSE_PKT)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_proxy_authentication_failed(self, mock_fromfd: mock.Mock,
                                         mock_selector: mock.Mock) -> None:
        """With auth configured, a request without Proxy-Authorization is rejected.

        The handler is rebuilt because the one from setUp uses default flags.
        Here an ``auth_code`` is configured and the proxy-auth plugin is
        loaded.
        """
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)
        flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
        flags.plugins = Proxy.load_plugins([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PROXY_AUTH),
        ])
        self.protocol_handler = HttpProtocolHandler(TcpClientConnection(
            self._conn, self._addr),
                                                    flags=flags)
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET http://abhinavsingh.com HTTP/1.1', b'Host: abhinavsingh.com',
            CRLF
        ])
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.client.buffer[0],
                         ProxyAuthenticationFailed.RESPONSE_PKT)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_authenticated_proxy_http_get(self,
                                          mock_server_connection: mock.Mock,
                                          mock_fromfd: mock.Mock,
                                          mock_selector: mock.Mock) -> None:
        """A GET carrying valid Proxy-Authorization is forwarded upstream.

        The Proxy-Authorization and Proxy-Connection headers do not appear in
        the packet expected upstream (see ``assert_data_queued``).

        NOTE(review): unlike test_proxy_authentication_failed, this test does
        not load PLUGIN_PROXY_AUTH, so credential validation may not actually
        be exercised — confirm this is intended.
        """
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)

        server = mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
        flags.plugins = Proxy.load_plugins([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])

        self.protocol_handler = HttpProtocolHandler(TcpClientConnection(
            self._conn, self._addr),
                                                    flags=flags)
        self.protocol_handler.initialize()
        assert self.http_server_port is not None

        # Request line without its CRLF: the parser must stay INITIALIZED...
        self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.INITIALIZED)

        # ...until the terminating CRLF arrives.
        self._conn.recv.return_value = CRLF
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.LINE_RCVD)

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port, b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            b'Proxy-Authorization: Basic dXNlcjpwYXNz', CRLF
        ])
        self.assert_data_queued(mock_server_connection, server)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_authenticated_proxy_http_tunnel(self,
                                             mock_server_connection: mock.Mock,
                                             mock_fromfd: mock.Mock,
                                             mock_selector: mock.Mock) -> None:
        """A CONNECT carrying Proxy-Authorization establishes a tunnel and relays data.

        NOTE(review): as in test_authenticated_proxy_http_get,
        PLUGIN_PROXY_AUTH is not loaded here — confirm credential validation
        is meant to be exercised.
        """
        server = mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0
        self._conn = mock_fromfd.return_value
        # Script: client readable (CONNECT), client readable (tunneled GET),
        # upstream writable (flush).
        self.mock_selector_for_client_read_read_server_write(
            mock_selector, server)

        flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
        flags.plugins = Proxy.load_plugins(
            [bytes_(PLUGIN_HTTP_PROXY),
             bytes_(PLUGIN_WEB_SERVER)])

        self.protocol_handler = HttpProtocolHandler(TcpClientConnection(
            self._conn, self._addr),
                                                    flags=flags)
        self.protocol_handler.initialize()

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Proxy-Connection: Keep-Alive',
            b'Proxy-Authorization: Basic dXNlcjpwYXNz', CRLF
        ])
        self.assert_tunnel_response(mock_server_connection, server)
        # The selector script has no client-writable event, so push the
        # tunnel response to the client socket directly.
        self.protocol_handler.client.flush()
        self.assert_data_queued_to_server(server)

        self.protocol_handler.run_once()
        server.flush.assert_called_once()

    def mock_selector_for_client_read_read_server_write(
            self, mock_selector: mock.Mock, server: mock.Mock) -> None:
        """Script three successive select() results.

        The results are: client readable, client readable again, then the
        upstream connection writable. Once all three are consumed, a further
        select() call raises StopIteration (``side_effect`` list exhausted).
        """
        mock_selector.return_value.select.side_effect = [
            [
                (selectors.SelectorKey(fileobj=self._conn,
                                       fd=self._conn.fileno,
                                       events=selectors.EVENT_READ,
                                       data=None), selectors.EVENT_READ),
            ],
            [
                (selectors.SelectorKey(fileobj=self._conn,
                                       fd=self._conn.fileno,
                                       events=0,
                                       data=None), selectors.EVENT_READ),
            ],
            [
                (selectors.SelectorKey(fileobj=server.connection,
                                       fd=server.connection.fileno,
                                       events=0,
                                       data=None), selectors.EVENT_WRITE),
            ],
        ]

    def assert_data_queued(self, mock_server_connection: mock.Mock,
                           server: mock.Mock) -> None:
        """Run one iteration and assert the proxied request was queued upstream.

        The request must be COMPLETE and an upstream connection must have been
        created and connected. The packet queued (exactly once) is the
        rewritten request: origin-form ``GET /``, proxy headers absent, and a
        ``Via`` header appended.
        """
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.COMPLETE)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False
        assert self.http_server_port is not None
        pkt = CRLF.join([
            b'GET / HTTP/1.1',
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port, b'Accept: */*',
            b'Via: 1.1 proxy.py v%s' % bytes_(__version__), CRLF
        ])
        server.queue.assert_called_once_with(pkt)
        # Simulate the queued bytes now pending on the upstream connection.
        server.buffer_size.return_value = len(pkt)

    def assert_data_queued_to_server(self, server: mock.Mock) -> None:
        """Assert the tunnel response reached the client, then relay a GET through the tunnel.

        The most recent ``send`` on the client socket must be the
        tunnel-established packet. A GET read from the client is then queued
        upstream byte-for-byte, with no flush yet.
        """
        assert self.http_server_port is not None
        self.assertEqual(self._conn.send.call_args[0][0],
                         HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)

        pkt = CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__), CRLF
        ])

        self._conn.recv.return_value = pkt
        self.protocol_handler.run_once()

        server.queue.assert_called_once_with(pkt)
        server.buffer_size.return_value = len(pkt)
        server.flush.assert_not_called()

    def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:
        """Make every select() call report the client socket as readable."""
        mock_selector.return_value.select.return_value = [
            (selectors.SelectorKey(fileobj=self._conn,
                                   fd=self._conn.fileno,
                                   events=selectors.EVENT_READ,
                                   data=None), selectors.EVENT_READ),
        ]
# Example #2
# 0
class TestWebServerPlugin(unittest.TestCase):
    """Tests for HttpWebServerPlugin: default 404s, static files and PAC files.

    The client socket (``socket.fromfd``) and ``selectors.DefaultSelector``
    are patched, and most tests rebuild the HttpProtocolHandler with the
    flags they need.
    """

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """Build a default handler with the proxy and web server plugins."""
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = mock_fromfd.return_value
        self.mock_selector = mock_selector
        self.flags = Flags()
        self.flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=self.flags)
        self.protocol_handler.initialize()

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_pac_file_served_from_disk(self, mock_fromfd: mock.Mock,
                                       mock_selector: mock.Mock) -> None:
        """A PAC file configured as a path is served with that file's bytes."""
        pac_file = os.path.join(os.path.dirname(PROXY_PY_DIR), 'helper',
                                'proxy.pac')
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)
        self.init_and_make_pac_file_request(pac_file)
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.COMPLETE)
        with open(pac_file, 'rb') as f:
            expected_body = f.read()
        self.assert_pac_file_response_queued(expected_body)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_pac_file_served_from_buffer(self, mock_fromfd: mock.Mock,
                                         mock_selector: mock.Mock) -> None:
        """PAC content passed inline (not a path) is served as-is."""
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)
        pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }'
        self.init_and_make_pac_file_request(text_(pac_file_content))
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.COMPLETE)
        self.assert_pac_file_response_queued(pac_file_content)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_default_web_server_returns_404(self, mock_fromfd: mock.Mock,
                                            mock_selector: mock.Mock) -> None:
        """An unknown path on the built-in web server yields DEFAULT_404_RESPONSE."""
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)
        flags = Flags()
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET /hello HTTP/1.1',
            CRLF,
        ])
        self.protocol_handler.run_once()
        self.assertEqual(self.protocol_handler.request.state,
                         httpParserStates.COMPLETE)
        self.assertEqual(self.protocol_handler.client.buffer[0],
                         HttpWebServerPlugin.DEFAULT_404_RESPONSE)

    @unittest.skipIf(
        os.environ.get('GITHUB_ACTIONS', False),
        'Disabled on GitHub actions because this test is flaky on GitHub infrastructure.')
    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_static_web_server_serves(self, mock_fromfd: mock.Mock,
                                      mock_selector: mock.Mock) -> None:
        """A file under static_server_dir is served gzip-encoded in one send()."""
        # Set up the static directory inside a private temporary directory
        # that is removed after the test, rather than a fixed <tmp>/static
        # path shared between users and concurrent runs.
        tmp_dir = tempfile.TemporaryDirectory()
        self.addCleanup(tmp_dir.cleanup)
        static_server_dir = os.path.join(tmp_dir.name, 'static')
        index_file_path = os.path.join(static_server_dir, 'index.html')
        html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''
        os.makedirs(static_server_dir, exist_ok=True)
        with open(index_file_path, 'wb') as f:
            f.write(html_file_content)

        self._conn = mock_fromfd.return_value
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/index.html')

        # Iteration 1 reads the request; iteration 2 writes the response.
        mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_WRITE,
                                    data=None), selectors.EVENT_WRITE)],
        ]

        flags = Flags(enable_static_server=True,
                      static_server_dir=static_server_dir)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )

        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()

        self.protocol_handler.run_once()
        self.protocol_handler.run_once()

        self.assertEqual(mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        encoded_html_file_content = gzip.compress(html_file_content)
        self.assertEqual(
            self._conn.send.call_args[0][0],
            build_http_response(200,
                                reason=b'OK',
                                headers={
                                    b'Content-Type':
                                    b'text/html',
                                    b'Cache-Control':
                                    b'max-age=86400',
                                    b'Content-Encoding':
                                    b'gzip',
                                    b'Connection':
                                    b'close',
                                    b'Content-Length':
                                    bytes_(len(encoded_html_file_content)),
                                },
                                body=encoded_html_file_content))

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_static_web_server_serves_404(self, mock_fromfd: mock.Mock,
                                          mock_selector: mock.Mock) -> None:
        """A missing static file yields DEFAULT_404_RESPONSE in one send()."""
        self._conn = mock_fromfd.return_value
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/not-found.html')

        # Iteration 1 reads the request; iteration 2 writes the response.
        mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_WRITE,
                                    data=None), selectors.EVENT_WRITE)],
        ]

        flags = Flags(enable_static_server=True)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )

        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()

        self.protocol_handler.run_once()
        self.protocol_handler.run_once()

        self.assertEqual(mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        self.assertEqual(self._conn.send.call_args[0][0],
                         HttpWebServerPlugin.DEFAULT_404_RESPONSE)

    @mock.patch('socket.fromfd')
    def test_on_client_connection_called_on_teardown(
            self, mock_fromfd: mock.Mock) -> None:
        """Plugins are instantiated on initialize and notified when the handler shuts down."""
        flags = Flags()
        plugin = mock.MagicMock()
        flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]}
        self._conn = mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()
        plugin.assert_called()
        # run_once() returning True stops run()'s loop immediately.
        with mock.patch.object(self.protocol_handler,
                               'run_once') as mock_run_once:
            mock_run_once.return_value = True
            self.protocol_handler.run()
        # NOTE(review): self._conn is a MagicMock, so ``.closed`` is an
        # auto-created (always truthy) child mock; this assertion cannot fail.
        # Consider asserting on an observable close call instead.
        self.assertTrue(self._conn.closed)
        plugin.return_value.on_client_connection_close.assert_called()

    def assert_pac_file_response_queued(self, body: bytes) -> None:
        """Assert the PAC response carrying ``body`` is queued for the client.

        The PAC tests script only "client readable" events, the same script
        test_default_web_server_returns_404 uses. So, like that test, this
        checks the first client buffer entry.
        """
        # A real assertion: the former ``self._conn.send.called_once_with(...)``
        # only invoked an auto-created child mock (always passing), and raises
        # AttributeError on Python 3.12+.
        self.assertEqual(
            bytes(self.protocol_handler.client.buffer[0]),
            build_http_response(200,
                                reason=b'OK',
                                headers={
                                    b'Content-Type':
                                    b'application/x-ns-proxy-autoconfig',
                                    b'Connection': b'close'
                                },
                                body=body))

    def init_and_make_pac_file_request(self, pac_file: str) -> None:
        """Rebuild the handler with the PAC plugin and stage ``GET /``.

        :param pac_file: value for the ``pac_file`` flag; the tests pass
            either a filesystem path or the PAC script text itself.
        """
        flags = Flags(pac_file=pac_file)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin,'
            b'proxy.http.server.HttpWebServerPacFilePlugin')
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET / HTTP/1.1',
            CRLF,
        ])

    def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:
        """Make every select() call report the client socket as readable."""
        mock_selector.return_value.select.return_value = [
            (selectors.SelectorKey(fileobj=self._conn,
                                   fd=self._conn.fileno,
                                   events=selectors.EVENT_READ,
                                   data=None), selectors.EVENT_READ),
        ]
# Example #3
# 0
class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
    """Example HttpProxyBasePlugins exercised through an intercepted TLS tunnel.

    ``setUp`` performs the CONNECT to ``uni.corn:443`` with certificate
    generation and TLS wrapping mocked. Each test then drives traffic through
    the (mock) SSL sockets.
    """

    @mock.patch('ssl.wrap_socket')
    @mock.patch('ssl.create_default_context')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    @mock.patch('proxy.http.proxy.server.gen_public_key')
    @mock.patch('proxy.http.proxy.server.gen_csr')
    @mock.patch('proxy.http.proxy.server.sign_csr')
    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
              mock_sign_csr: mock.Mock, mock_gen_csr: mock.Mock,
              mock_gen_public_key: mock.Mock, mock_server_conn: mock.Mock,
              mock_ssl_context: mock.Mock, mock_ssl_wrap: mock.Mock) -> None:
        """Establish an intercepted TLS tunnel before each test.

        Steps:
        - the CONNECT request is read from a mocked client socket;
        - CSR/key/signing helpers are mocked to succeed;
        - the client socket is "wrapped" into ``self.client_ssl_connection``
          and the upstream into ``self.server_ssl_connection``.

        Decorators apply bottom-up, which is why the mock arguments run from
        ``socket.fromfd`` to ``ssl.wrap_socket``. The patches end when setUp
        returns; tests rely on the objects wired up here.
        """
        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector
        self.mock_sign_csr = mock_sign_csr
        self.mock_gen_csr = mock_gen_csr
        self.mock_gen_public_key = mock_gen_public_key
        self.mock_server_conn = mock_server_conn
        self.mock_ssl_context = mock_ssl_context
        self.mock_ssl_wrap = mock_ssl_wrap

        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        # NOTE(review): self.fileno and self.plugin are not referenced
        # elsewhere in this class.
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        # CA files are configured; the assertions at the end of setUp check
        # that the client connection ends up TLS-wrapped (interception).
        self.flags = Flags(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
        )
        self.plugin = mock.MagicMock()

        # Presumably maps the running test's name to the example plugin under
        # test — see get_plugin_by_test_name.
        plugin = get_plugin_by_test_name(self._testMethodName)

        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = mock.MagicMock(spec=socket.socket)
        mock_fromfd.return_value = self._conn
        self.protocol_handler = HttpProtocolHandler(TcpClientConnection(
            self._conn, self._addr),
                                                    flags=self.flags)
        self.protocol_handler.initialize()

        self.server = self.mock_server_conn.return_value

        self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection
        self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = self.client_ssl_connection

        def has_buffer() -> bool:
            """Report pending upstream data once anything has been queued."""
            return cast(bool, self.server.queue.called)

        def closed() -> bool:
            """Treat the upstream as closed until connect() has been called."""
            return not self.server.connect.called

        def mock_connection() -> Any:
            """Return the upstream socket as seen by the handler.

            Returns the server SSL mock once the upstream has been wrapped;
            before that, the plain ``self._conn`` mock.
            """
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return self.server_ssl_connection
            return self._conn

        self.server.has_buffer.side_effect = has_buffer
        # PropertyMocks must be attached to the mock's type. Each Mock
        # instance has its own class, so this only affects self.server.
        type(self.server).closed = mock.PropertyMock(side_effect=closed)
        type(self.server).connection = mock.PropertyMock(
            side_effect=mock_connection)

        # Scripted select() results, consumed in order (one per run_once(),
        # assuming one select() per iteration):
        #   1. plain client readable   -> CONNECT, below in setUp
        #   2. TLS client readable     -> request from the client
        #   3. TLS upstream writable   -> request flushed upstream
        #   4. TLS upstream readable   -> response from upstream
        # Tests start consuming from entry 2.
        self.mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=self.client_ssl_connection,
                                    fd=self.client_ssl_connection.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=self.server_ssl_connection,
                                    fd=self.server_ssl_connection.fileno,
                                    events=selectors.EVENT_WRITE,
                                    data=None), selectors.EVENT_WRITE)],
            [(selectors.SelectorKey(fileobj=self.server_ssl_connection,
                                    fd=self.server_ssl_connection.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
        ]

        # Connect
        def send(raw: bytes) -> int:
            """Pretend every send() writes the whole buffer."""
            return len(raw)

        self._conn.send.side_effect = send
        self._conn.recv.return_value = build_http_request(
            httpMethods.CONNECT, b'uni.corn:443')
        self.protocol_handler.run_once()

        # Interception generated and signed a certificate exactly once.
        self.assertEqual(self.mock_sign_csr.call_count, 1)
        self.assertEqual(self.mock_gen_csr.call_count, 1)
        self.assertEqual(self.mock_gen_public_key.call_count, 1)

        self.mock_server_conn.assert_called_once_with('uni.corn', 443)
        self.server.connect.assert_called()
        self.assertEqual(self.protocol_handler.client.connection,
                         self.client_ssl_connection)
        self.assertEqual(self.server.connection, self.server_ssl_connection)
        self._conn.send.assert_called_with(
            HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        self.assertFalse(self.protocol_handler.client.has_buffer())

    def test_modify_post_data_plugin(self) -> None:
        """A form-encoded POST is rewritten before being queued upstream.

        The queued request has a JSON body and updated Content-Type and
        Content-Length headers.
        """
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'
        self.client_ssl_connection.recv.return_value = build_http_request(
            b'POST',
            b'/',
            headers={
                b'Host': b'uni.corn',
                b'Content-Type': b'application/x-www-form-urlencoded',
                b'Content-Length': bytes_(len(original)),
            },
            body=original)
        self.protocol_handler.run_once()
        self.server.queue.assert_called_with(
            build_http_request(b'POST',
                               b'/',
                               headers={
                                   b'Host': b'uni.corn',
                                   b'Content-Length': bytes_(len(modified)),
                                   b'Content-Type': b'application/json',
                               },
                               body=modified))

    def test_man_in_the_middle_plugin(self) -> None:
        """The upstream response is replaced before it is queued to the client.

        The request passes through unchanged; the client receives
        "Hello from man in the middle" instead of the upstream body.
        """
        request = build_http_request(b'GET',
                                     b'/',
                                     headers={
                                         b'Host': b'uni.corn',
                                     })
        self.client_ssl_connection.recv.return_value = request

        # Client read
        self.protocol_handler.run_once()
        self.server.queue.assert_called_once_with(request)

        # Server write
        self.protocol_handler.run_once()
        self.server.flush.assert_called_once()

        # Server read
        self.server.recv.return_value = \
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK', body=b'Original Response From Upstream')
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            build_http_response(httpStatusCodes.OK,
                                reason=b'OK',
                                body=b'Hello from man in the middle'))
# Example #4
# 0
class TestHttpProxyPluginExamples(unittest.TestCase):
    """Run each example ``HttpProxyBasePlugin`` through ``HttpProxyPlugin``.

    ``setUp`` picks the plugin under test from the current test method's
    name via ``get_plugin_by_test_name`` and registers it alongside
    ``HttpProxyPlugin``. The client socket and the selector are mocked in
    ``setUp``; each test additionally mocks the upstream
    ``TcpServerConnection``.
    """

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = Flags()
        # NOTE(review): not referenced by any test in this class.
        self.plugin = mock.MagicMock()

        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector

        # Which example plugin is exercised depends on the running test.
        plugin = get_plugin_by_test_name(self._testMethodName)

        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr),
            flags=self.flags)
        self.protocol_handler.initialize()

    @staticmethod
    def _selector_event(fileobj: mock.Mock, event: int) -> list:
        """Build one scripted ``select()`` result: ``event`` ready on ``fileobj``.

        The result is a single-element list of ``(SelectorKey, event)``, the
        shape each entry of the mocked ``select.side_effect`` takes. ``fd`` is
        the mock's ``fileno`` attribute itself, not a real descriptor.
        """
        key = selectors.SelectorKey(fileobj=fileobj,
                                    fd=fileobj.fileno,
                                    events=event,
                                    data=None)
        return [(key, event)]

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_modify_post_data_plugin(self,
                                     mock_server_conn: mock.Mock) -> None:
        """A form-encoded POST body is rewritten to JSON before going upstream.

        The queued upstream request also uses the origin-form path, gets an
        updated ``Content-Length`` and carries the proxy's ``Via`` header.
        """
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'

        self._conn.recv.return_value = build_http_request(
            b'POST',
            b'http://httpbin.org/post',
            headers={
                b'Host': b'httpbin.org',
                b'Content-Type': b'application/x-www-form-urlencoded',
                b'Content-Length': bytes_(len(original)),
            },
            body=original)
        self.mock_selector.return_value.select.side_effect = [
            self._selector_event(self._conn, selectors.EVENT_READ),
        ]

        self.protocol_handler.run_once()
        mock_server_conn.assert_called_with('httpbin.org', DEFAULT_HTTP_PORT)
        mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(b'POST',
                               b'/post',
                               headers={
                                   b'Host': b'httpbin.org',
                                   b'Content-Length': bytes_(len(modified)),
                                   b'Content-Type': b'application/json',
                                   b'Via':
                                   b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                               },
                               body=modified))

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_proposed_rest_api_plugin(self,
                                      mock_server_conn: mock.Mock) -> None:
        """API requests are answered locally from ``REST_API_SPEC`` as JSON.

        No upstream connection is constructed.
        """
        path = b'/v1/users/'
        self._conn.recv.return_value = build_http_request(
            b'GET',
            b'http://%s%s' % (ProposedRestApiPlugin.API_SERVER, path),
            headers={
                b'Host': ProposedRestApiPlugin.API_SERVER,
            })
        self.mock_selector.return_value.select.side_effect = [
            self._selector_event(self._conn, selectors.EVENT_READ),
        ]
        self.protocol_handler.run_once()

        mock_server_conn.assert_not_called()
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK',
                headers={b'Content-Type': b'application/json'},
                body=bytes_(
                    json.dumps(ProposedRestApiPlugin.REST_API_SPEC[path]))))

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_redirect_to_custom_server_plugin(
            self, mock_server_conn: mock.Mock) -> None:
        """A request for example.org is re-targeted at ``UPSTREAM_SERVER``."""
        request = build_http_request(b'GET',
                                     b'http://example.org/get',
                                     headers={
                                         b'Host': b'example.org',
                                     })
        self._conn.recv.return_value = request
        self.mock_selector.return_value.select.side_effect = [
            self._selector_event(self._conn, selectors.EVENT_READ),
        ]
        self.protocol_handler.run_once()

        upstream = urlparse.urlsplit(
            RedirectToCustomServerPlugin.UPSTREAM_SERVER)
        # NOTE(review): host/port are hard-coded here and must stay in sync
        # with RedirectToCustomServerPlugin.UPSTREAM_SERVER (parsed above).
        mock_server_conn.assert_called_with('localhost', 8899)
        mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(b'GET',
                               upstream.path,
                               headers={
                                   b'Host': upstream.netloc,
                                   b'Via':
                                   b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                               }))

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_filter_by_upstream_host_plugin(
            self, mock_server_conn: mock.Mock) -> None:
        """Requests for google.com get a 418 with ``Connection: close``.

        No upstream connection is constructed.
        """
        request = build_http_request(b'GET',
                                     b'http://google.com/',
                                     headers={
                                         b'Host': b'google.com',
                                     })
        self._conn.recv.return_value = request
        self.mock_selector.return_value.select.side_effect = [
            self._selector_event(self._conn, selectors.EVENT_READ),
        ]
        self.protocol_handler.run_once()

        mock_server_conn.assert_not_called()
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            build_http_response(
                status_code=httpStatusCodes.I_AM_A_TEAPOT,
                reason=b'I\'m a tea pot',
                headers={b'Connection': b'close'},
            ))

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_man_in_the_middle_plugin(self,
                                      mock_server_conn: mock.Mock) -> None:
        """The upstream response body is replaced before reaching the client.

        Three scripted selector rounds drive the handler: client readable,
        upstream writable, then upstream readable.
        """
        request = build_http_request(b'GET',
                                     b'http://super.secure/',
                                     headers={
                                         b'Host': b'super.secure',
                                     })
        self._conn.recv.return_value = request

        server = mock_server_conn.return_value
        server.connect.return_value = True

        # The mocked upstream reports pending data once something has been
        # queued, and reports itself closed until connect() has been called.
        def has_buffer() -> bool:
            return cast(bool, server.queue.called)

        def closed() -> bool:
            return not server.connect.called

        server.has_buffer.side_effect = has_buffer
        type(server).closed = mock.PropertyMock(side_effect=closed)

        self.mock_selector.return_value.select.side_effect = [
            self._selector_event(self._conn, selectors.EVENT_READ),
            self._selector_event(server.connection, selectors.EVENT_WRITE),
            self._selector_event(server.connection, selectors.EVENT_READ),
        ]

        # Client read
        self.protocol_handler.run_once()
        mock_server_conn.assert_called_with('super.secure', DEFAULT_HTTP_PORT)
        server.connect.assert_called_once()
        queued_request = \
            build_http_request(
                b'GET', b'/',
                headers={
                    b'Host': b'super.secure',
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE
                }
            )
        server.queue.assert_called_once_with(queued_request)

        # Server write
        self.protocol_handler.run_once()
        server.flush.assert_called_once()

        # Server read
        server.recv.return_value = \
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK', body=b'Original Response From Upstream')
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            build_http_response(httpStatusCodes.OK,
                                reason=b'OK',
                                body=b'Hello from man in the middle'))
# Example #5
class TestHttpProxyPlugin(unittest.TestCase):
    """Check how ``HttpProxyPlugin`` drives a mocked ``HttpProxyBasePlugin``."""

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self,
              mock_fromfd: mock.Mock,
              mock_selector: mock.Mock) -> None:
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector
        self._conn = mock_fromfd.return_value

        # The proxy plugin "class" is a MagicMock, so tests can check that it
        # was instantiated and which hooks were invoked on the instance.
        self.plugin = mock.MagicMock()
        self.flags = Proxy.initialize()
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.plugin],
        }

        client = TcpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

    def test_proxy_plugin_initialized(self) -> None:
        """The configured proxy plugin class has been called (instantiated)."""
        self.plugin.assert_called()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_proxy_plugin_on_and_before_upstream_connection(
            self,
            mock_server_conn: mock.Mock) -> None:
        """Both request hooks run, then the upstream connection is built.

        The connection targets the request's host on ``DEFAULT_HTTP_PORT``.
        """
        hooks = self.plugin.return_value
        hooks.before_upstream_connection.side_effect = lambda r: r
        hooks.handle_client_request.side_effect = lambda r: r

        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://upstream.host/not-found.html',
            headers={b'Host': b'upstream.host'})
        client_readable = selectors.SelectorKey(
            fileobj=self._conn,
            fd=self._conn.fileno,
            events=selectors.EVENT_READ,
            data=None)
        self.mock_selector.return_value.select.side_effect = [
            [(client_readable, selectors.EVENT_READ)],
        ]

        self.protocol_handler.run_once()

        mock_server_conn.assert_called_with('upstream.host', DEFAULT_HTTP_PORT)
        hooks.before_upstream_connection.assert_called()
        hooks.handle_client_request.assert_called()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_proxy_plugin_before_upstream_connection_can_teardown(
            self,
            mock_server_conn: mock.Mock) -> None:
        """An exception from ``before_upstream_connection`` stops the upstream.

        When the hook raises ``HttpProtocolException``, no upstream
        connection is ever constructed.
        """
        hooks = self.plugin.return_value
        hooks.before_upstream_connection.side_effect = HttpProtocolException()

        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://upstream.host/not-found.html',
            headers={b'Host': b'upstream.host'})
        client_readable = selectors.SelectorKey(
            fileobj=self._conn,
            fd=self._conn.fileno,
            events=selectors.EVENT_READ,
            data=None)
        self.mock_selector.return_value.select.side_effect = [
            [(client_readable, selectors.EVENT_READ)],
        ]

        self.protocol_handler.run_once()

        hooks.before_upstream_connection.assert_called()
        mock_server_conn.assert_not_called()
class TestHttpProxyTlsInterception(unittest.TestCase):
    """End-to-end check of CONNECT handling with CA files configured.

    The client socket, selector, ``subprocess.Popen``, the upstream
    ``TcpServerConnection`` and the ``ssl`` entry points used here are all
    mocked via ``mock.patch``.
    """
    @mock.patch('ssl.wrap_socket')
    @mock.patch('ssl.create_default_context')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    @mock.patch('subprocess.Popen')
    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_e2e(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
                 mock_popen: mock.Mock, mock_server_conn: mock.Mock,
                 mock_ssl_context: mock.Mock,
                 mock_ssl_wrap: mock.Mock) -> None:
        """CONNECT to a random host:443 and check the TLS interception path.

        Asserts the following:

        * The upstream socket is wrapped by a default SSL context.
        * The client receives the tunnel-established packet.
        * The client socket is wrapped server-side with a per-host
          generated certificate.
        * Both mocked plugins end up referencing the wrapped client socket.
        """
        # Random hostname on the standard HTTPS port.
        host, port = uuid.uuid4().hex, 443
        netloc = '{0}:{1}'.format(host, port)

        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector
        self.mock_popen = mock_popen
        self.mock_server_conn = mock_server_conn
        self.mock_ssl_context = mock_ssl_context
        self.mock_ssl_wrap = mock_ssl_wrap

        # Upstream sockets:
        # - ssl_connection is what the mocked SSL context's wrap_socket
        #   returns.
        # - plain_connection stands in for the raw upstream socket.
        ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection
        self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket)
        plain_connection = mock.MagicMock(spec=socket.socket)

        # `connection` yields the plain socket until the SSL context's
        # wrap_socket has been called, and the SSL socket afterwards.
        def mock_connection() -> Any:
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return ssl_connection
            return plain_connection

        type(self.mock_server_conn.return_value).connection = \
            mock.PropertyMock(side_effect=mock_connection)

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        # ca_signing_key_file is asserted below as the keyfile used when
        # wrapping the client socket.
        self.flags = Flags(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
        )
        # Plugins:
        # - self.plugin: a mocked HttpProtocolHandlerPlugin, registered ahead
        #   of HttpProxyPlugin.
        # - self.proxy_plugin: a mocked HttpProxyBasePlugin.
        self.plugin = mock.MagicMock()
        self.proxy_plugin = mock.MagicMock()
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.proxy_plugin],
        }
        self._conn = mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(TcpClientConnection(
            self._conn, self._addr),
                                                    flags=self.flags)
        self.protocol_handler.initialize()

        # Both plugin classes were constructed with the flags as the 2nd
        # positional argument and a client wrapper (whose .connection is the
        # raw client socket) as the 3rd.
        self.plugin.assert_called()
        self.assertEqual(self.plugin.call_args[0][1], self.flags)
        self.assertEqual(self.plugin.call_args[0][2].connection, self._conn)
        self.proxy_plugin.assert_called()
        self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags)
        self.assertEqual(self.proxy_plugin.call_args[0][2].connection,
                         self._conn)

        connect_request = build_http_request(httpMethods.CONNECT,
                                             bytes_(netloc),
                                             headers={
                                                 b'Host': bytes_(netloc),
                                             })
        self._conn.recv.return_value = connect_request

        # Prepare mocked HttpProtocolHandlerPlugin (data hooks pass through)
        self.plugin.return_value.get_descriptors.return_value = ([], [])
        self.plugin.return_value.write_to_descriptors.return_value = False
        self.plugin.return_value.read_from_descriptors.return_value = False
        self.plugin.return_value.on_client_data.side_effect = lambda raw: raw
        self.plugin.return_value.on_request_complete.return_value = False
        self.plugin.return_value.on_response_chunk.side_effect = lambda chunk: chunk
        self.plugin.return_value.on_client_connection_close.return_value = None

        # Prepare mocked HttpProxyBasePlugin (request hooks pass through)
        self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r
        self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r

        # A single scripted select() round: the client socket is readable.
        self.mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
        ]
        self.protocol_handler.run_once()

        # Assert our mocked plugins invocations
        self.plugin.return_value.get_descriptors.assert_called()
        self.plugin.return_value.write_to_descriptors.assert_called_with([])
        self.plugin.return_value.on_client_data.assert_called_with(
            connect_request)
        self.plugin.return_value.on_request_complete.assert_called()
        self.plugin.return_value.read_from_descriptors.assert_called_with(
            [self._conn])
        self.proxy_plugin.return_value.before_upstream_connection.assert_called(
        )
        self.proxy_plugin.return_value.handle_client_request.assert_called()

        # Upstream connection targets the CONNECT host:port and is switched
        # to non-blocking mode.
        self.mock_server_conn.assert_called_with(host, port)
        self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
            False)

        # Upstream TLS: a default SERVER_AUTH context wraps the plain upstream
        # socket with server_hostname=host.
        self.mock_ssl_context.assert_called_with(ssl.Purpose.SERVER_AUTH)
        # self.assertEqual(self.mock_ssl_context.return_value.options,
        # ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 |
        # ssl.OP_NO_TLSv1_1)
        self.assertEqual(plain_connection.setblocking.call_count, 2)
        self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
            plain_connection, server_hostname=host)
        # NOTE(review): the two Popen calls are presumably the per-host
        # certificate generation steps — confirm (see TODO).
        # TODO: Assert Popen arguments, piping, success condition
        self.assertEqual(self.mock_popen.call_count, 2)
        self.assertEqual(ssl_connection.setblocking.call_count, 1)
        # The upstream connection's underlying socket is now the SSL socket.
        self.assertEqual(self.mock_server_conn.return_value._conn,
                         ssl_connection)
        # The client is told the tunnel is established...
        self._conn.send.assert_called_with(
            HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        # ...and its socket is wrapped server-side with the generated cert
        # for `host` plus the CA signing key.
        assert self.flags.ca_cert_dir is not None
        self.mock_ssl_wrap.assert_called_with(
            self._conn,
            server_side=True,
            keyfile=self.flags.ca_signing_key_file,
            certfile=HttpProxyPlugin.generated_cert_file_path(
                self.flags.ca_cert_dir, host))
        self.assertEqual(self._conn.setblocking.call_count, 2)
        self.assertEqual(self.protocol_handler.client.connection,
                         self.mock_ssl_wrap.return_value)

        # Assert connection references for all other plugins is updated
        self.assertEqual(self.plugin.return_value.client._conn,
                         self.mock_ssl_wrap.return_value)
        self.assertEqual(self.proxy_plugin.return_value.client._conn,
                         self.mock_ssl_wrap.return_value)