示例#1
0
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Prepare a protocol handler that proxies through one example plugin.

        The plugin class is resolved from ``request.param`` (a test name)
        via ``get_plugin_by_test_name``. ``socket.fromfd``, the default
        selector and the upstream ``TcpServerConnection`` are patched so the
        handler never touches a real socket.
        """
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)

        # The URL-regex filter example reads its rules from adblock.json,
        # located relative to this test file (presumably the repo root).
        repo_root = Path(__file__).parent.parent.parent
        filter_config = repo_root / "proxy" / "plugin" / "adblock.json"
        self.flags = FlagParser.initialize(
            input_args=["--filtered-url-regex-config", str(filter_config)],
            threaded=True,
        )
        self.plugin = mock.MagicMock()

        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [get_plugin_by_test_name(request.param)],
        }
        self._conn = self.mock_fromfd.return_value
        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()
示例#2
0
 def _setUp(self, mocker: MockerFixture) -> None:
     """Build a handler with the static-file web server enabled.

     Writes a small ``index.html`` into ``<tempdir>/static`` and points the
     static server at that directory so tests can request it.
     """
     self.mock_fromfd = mocker.patch('socket.fromfd')
     self.mock_selector = mocker.patch('selectors.DefaultSelector')
     self.fileno = 10
     self._addr = ('127.0.0.1', 54382)
     self._conn = self.mock_fromfd.return_value

     # Materialise the static directory and its single index page.
     static_dir = os.path.join(tempfile.gettempdir(), 'static')
     self.static_server_dir = static_dir
     self.index_file_path = os.path.join(static_dir, 'index.html')
     self.html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''
     os.makedirs(static_dir, exist_ok=True)
     with open(self.index_file_path, 'wb') as index_file:
         index_file.write(self.html_file_content)

     flags = FlagParser.initialize(
         enable_static_server=True,
         static_server_dir=static_dir,
         threaded=True,
     )
     flags.plugins = Plugins.load(
         [bytes_(name) for name in (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER)],
     )
     client = HttpClientConnection(self._conn, self._addr)
     self.protocol_handler = HttpProtocolHandler(client, flags=flags)
     self.protocol_handler.initialize()
示例#3
0
 async def test_proxy_authentication_failed(self) -> None:
     """A request without credentials gets the proxy-auth-failed response.

     The proxy is configured with ``auth_code`` for ``user:pass``; the
     request below carries no Proxy-Authorization header.
     """
     self._conn = self.mock_fromfd.return_value
     mock_selector_for_client_read(self)

     flags = FlagParser.initialize(
         auth_code=base64.b64encode(b'user:pass'),
         threaded=True,
     )
     plugin_names = (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER, PLUGIN_PROXY_AUTH)
     flags.plugins = Plugins.load([bytes_(name) for name in plugin_names])
     handler = HttpProtocolHandler(
         HttpClientConnection(self._conn, self._addr),
         flags=flags,
     )
     self.protocol_handler = handler
     handler.initialize()

     self._conn.recv.return_value = CRLF.join([
         b'GET http://abhinavsingh.com HTTP/1.1',
         b'Host: abhinavsingh.com',
         CRLF,
     ])
     await handler._run_once()
     self.assertEqual(handler.work.buffer[0], PROXY_AUTH_FAILED_RESPONSE_PKT)
示例#4
0
 def _setUp(self, request: Any, mocker: MockerFixture) -> None:
     """Build a handler serving a PAC file and queue a ``GET /`` request.

     ``request.param`` is either a path to a PAC file on disk (``str``) or
     the PAC content itself. ``self.expected_response`` is set to the file's
     bytes in the first case and to ``PAC_FILE_CONTENT`` otherwise.
     """
     self.mock_fromfd = mocker.patch('socket.fromfd')
     self.mock_selector = mocker.patch('selectors.DefaultSelector')
     self.fileno = 10
     self._addr = ('127.0.0.1', 54382)
     self._conn = self.mock_fromfd.return_value

     self.pac_file = request.param
     if not isinstance(self.pac_file, str):
         self.expected_response = PAC_FILE_CONTENT
     else:
         with open(self.pac_file, 'rb') as pac:
             self.expected_response = pac.read()

     self.flags = FlagParser.initialize(pac_file=self.pac_file, threaded=True)
     self.flags.plugins = Plugins.load([
         bytes_(name)
         for name in (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER, PLUGIN_PAC_FILE)
     ])
     client = HttpClientConnection(self._conn, self._addr)
     self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
     self.protocol_handler.initialize()

     # Every recv() on the client socket yields a bare ``GET /`` request.
     self._conn.recv.return_value = CRLF.join([b'GET / HTTP/1.1', CRLF])
     mock_selector_for_client_read(self)
示例#5
0
class TestWebServerPlugin(Assertions):
    """Built-in web server behaviour when no route matches the request."""

    @pytest.fixture(autouse=True)   # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        """Build a threaded handler with the proxy and web-server plugins.

        ``socket.fromfd`` and the default selector are patched so the
        handler operates purely on mocks.
        """
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value
        self.flags = FlagParser.initialize(threaded=True)
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    @pytest.mark.asyncio    # type: ignore[misc]
    async def test_default_web_server_returns_404(self) -> None:
        """An unknown path on the built-in web server yields a 404.

        Reuses the handler built by the autouse fixture: it already has the
        exact flags (``threaded=True``) and plugins this test needs, so the
        test does not rebuild an identical handler.
        """
        # Every select() reports the client socket as readable.
        self.mock_selector.return_value.select.return_value = [
            (
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            ),
        ]
        self._conn.recv.return_value = CRLF.join([
            b'GET /hello HTTP/1.1',
            CRLF,
        ])
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            NOT_FOUND_RESPONSE_PKT,
        )
示例#6
0
class TestWebServerPluginWithPacFilePlugin(Assertions):
    """Serving a PAC file via the PAC plugin, configured by path or content."""

    @pytest.fixture(
        autouse=True, params=[
            PAC_FILE_PATH,
            PAC_FILE_CONTENT,
        ],
    )  # type: ignore[misc]
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Build a PAC-serving handler and queue a ``GET /`` from the client.

        ``request.param`` is either a path to a PAC file (``str``), whose
        bytes become ``self.expected_response``, or the PAC content itself.
        """
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value
        self.pac_file = request.param
        if isinstance(self.pac_file, str):
            with open(self.pac_file, 'rb') as f:
                self.expected_response = f.read()
        else:
            self.expected_response = PAC_FILE_CONTENT
        self.flags = FlagParser.initialize(
            pac_file=self.pac_file, threaded=True,
        )
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PAC_FILE),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET / HTTP/1.1',
            CRLF,
        ])
        mock_selector_for_client_read(self)

    @pytest.mark.asyncio    # type: ignore[misc]
    async def test_pac_file_served_from_disk(self) -> None:
        """``GET /`` is answered with a 200 carrying the PAC content."""
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )
        # ``Mock.called_once_with`` is not an assertion: it merely creates a
        # child mock and can never fail (Python 3.12+ raises AttributeError
        # for it). The handler queues responses on the client work buffer,
        # so check the queued response directly.
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            build_http_response(
                200,
                reason=b'OK',
                headers={
                    b'Content-Type': b'application/x-ns-proxy-autoconfig',
                },
                body=self.expected_response,
                conn_close=True,
            ),
        )
示例#7
0
    async def test_authenticated_proxy_http_get(self) -> None:
        """A GET carrying valid Proxy-Authorization is forwarded upstream.

        The request is fed in three ``recv`` chunks (request line, bare
        CRLF, headers) so the parser state can be checked after each one.
        """
        self._conn = self.mock_fromfd.return_value
        mock_selector_for_client_read(self)

        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        flags.plugins = Plugins.load(
            [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)],
        )
        handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler = handler
        handler.initialize()

        port = self.http_server_port
        assert port is not None

        # Chunk 1: request line without its terminator -> still INITIALIZED.
        self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % port
        await handler._run_once()
        self.assertEqual(handler.request.state, httpParserStates.INITIALIZED)

        # Chunk 2: the CRLF completes the request line.
        self._conn.recv.return_value = CRLF
        await handler._run_once()
        self.assertEqual(handler.request.state, httpParserStates.LINE_RCVD)

        # Chunk 3: headers; ``dXNlcjpwYXNz`` is base64 for ``user:pass``.
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % port,
            b'Accept: */*',
            httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
            httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
            CRLF,
        ])
        await self.assert_data_queued(server)
示例#8
0
 def _setUp(self, mocker: MockerFixture) -> None:
     """Create a threaded handler with the proxy and web-server plugins.

     ``socket.fromfd`` and the default selector are patched so no real
     socket or selector is used.
     """
     self.mock_fromfd = mocker.patch('socket.fromfd')
     self.mock_selector = mocker.patch('selectors.DefaultSelector')
     self.fileno, self._addr = 10, ('127.0.0.1', 54382)
     self._conn = self.mock_fromfd.return_value

     self.flags = FlagParser.initialize(threaded=True)
     enabled = (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER)
     self.flags.plugins = Plugins.load([bytes_(name) for name in enabled])

     client = HttpClientConnection(self._conn, self._addr)
     self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
     self.protocol_handler.initialize()
    def _setUp(self, mocker: MockerFixture) -> None:
        """Create a handler that requires proxy basic auth ``user:pass``.

        The upstream ``TcpServerConnection`` is patched together with
        ``socket.fromfd`` and the selector, so no real connection is opened.
        """
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        cli_args = ["--basic-auth", "user:pass"]
        self.flags = FlagParser.initialize(cli_args, threaded=True)
        self._conn = self.mock_fromfd.return_value

        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()
示例#10
0
    def _setUp(self, mocker: MockerFixture) -> None:
        """Create a handler whose only proxy plugin is a ``MagicMock``.

        The mock plugin is kept on ``self.plugin`` for later inspection.
        The upstream connection class, selector and ``socket.fromfd`` are
        all patched.
        """
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_fromfd = mocker.patch('socket.fromfd')

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(threaded=True)
        self.plugin = mocker.MagicMock()
        plugin_map = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.plugin],
        }
        self.flags.plugins = plugin_map

        self._conn = self.mock_fromfd.return_value
        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()
示例#11
0
    async def test_authenticated_proxy_http_tunnel(self) -> None:
        """A CONNECT with valid Proxy-Authorization establishes a tunnel.

        Checks the tunnel response, then (after flushing the client work)
        the data queued to the upstream, and finally that one more loop
        iteration flushes the upstream exactly once.
        """
        upstream = self.mock_server_connection.return_value
        upstream.connect.return_value = True
        upstream.buffer_size.return_value = 0
        self._conn = self.mock_fromfd.return_value
        self.mock_selector_for_client_read_and_server_write(upstream)

        flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        flags.plugins = Plugins.load(
            [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)],
        )
        handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler = handler
        handler.initialize()

        port = self.http_server_port
        assert port is not None
        # ``dXNlcjpwYXNz`` is base64 for ``user:pass``.
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % port,
            b'Host: localhost:%d' % port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
            httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
            CRLF,
        ])
        await self.assert_tunnel_response(upstream)
        handler.work.flush()
        await self.assert_data_queued_to_server(upstream)

        await handler._run_once()
        upstream.flush.assert_called_once()
示例#12
0
 async def test_default_web_server_returns_404(self) -> None:
     """Requesting an unknown path from the built-in web server yields 404."""
     self._conn = self.mock_fromfd.return_value
     client_fd = self._conn.fileno()
     read_key = selectors.SelectorKey(
         fileobj=client_fd,
         fd=client_fd,
         events=selectors.EVENT_READ,
         data=None,
     )
     # Every select() reports the client socket as readable.
     self.mock_selector.return_value.select.return_value = [
         (read_key, selectors.EVENT_READ),
     ]

     flags = FlagParser.initialize(threaded=True)
     flags.plugins = Plugins.load(
         [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)],
     )
     handler = HttpProtocolHandler(
         HttpClientConnection(self._conn, self._addr),
         flags=flags,
     )
     self.protocol_handler = handler
     handler.initialize()

     self._conn.recv.return_value = CRLF.join([b'GET /hello HTTP/1.1', CRLF])
     await handler._run_once()
     self.assertEqual(handler.request.state, httpParserStates.COMPLETE)
     self.assertEqual(handler.work.buffer[0], NOT_FOUND_RESPONSE_PKT)
示例#13
0
def test_on_client_connection_called_on_teardown(mocker: MockerFixture) -> None:
    """Running the handler to completion must close the client socket.

    ``_run_once`` is patched to return ``True``; this test relies on that
    making ``run()`` stop and proceed to teardown.
    """
    plugin = mocker.MagicMock()
    mock_fromfd = mocker.patch('socket.fromfd')
    flags = FlagParser.initialize(threaded=True)
    flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]}
    _conn = mock_fromfd.return_value
    _addr = ('127.0.0.1', 54382)
    protocol_handler = HttpProtocolHandler(
        HttpClientConnection(_conn, _addr),
        flags=flags,
    )
    protocol_handler.initialize()
    # Initialising the handler must not instantiate the plugin class.
    plugin.assert_not_called()
    mock_run_once = mocker.patch.object(protocol_handler, '_run_once')
    mock_run_once.return_value = True
    protocol_handler.run()
    # ``_conn`` is a MagicMock, so ``_conn.closed`` would be an auto-created,
    # always-truthy child mock and asserting on it could never fail. Verify
    # the socket was actually closed instead.
    _conn.close.assert_called()
示例#14
0
class TestHttpProxyPluginExamples(Assertions):
    """End-to-end checks for the example ``HttpProxyBasePlugin`` classes.

    Each test picks its plugin through indirect parametrisation of the
    ``_setUp`` fixture; the parameter is a test name that
    ``get_plugin_by_test_name`` maps to a plugin class.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Build a handler that proxies through the plugin for ``request.param``.

        ``socket.fromfd``, the default selector and the upstream
        ``TcpServerConnection`` are patched so nothing touches the network.
        """
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        # Rules consumed by the URL-regex filter example plugin.
        adblock_json_path = Path(
            __file__,
        ).parent.parent.parent / "proxy" / "plugin" / "adblock.json"
        self.flags = FlagParser.initialize(
            input_args=[
                "--filtered-url-regex-config",
                str(adblock_json_path),
            ],
            threaded=True,
        )

        plugin = get_plugin_by_test_name(request.param)

        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    @staticmethod
    def _ready(fd: Any, event: int) -> Any:
        """Return one ``select()`` result reporting *fd* ready for *event*."""
        return [(
            selectors.SelectorKey(
                fileobj=fd,
                fd=fd,
                events=event,
                data=None,
            ),
            event,
        )]

    def _client_readable_once(self) -> None:
        """Script a single ``select()`` round where the client is readable."""
        self.mock_selector.return_value.select.side_effect = [
            self._ready(self._conn.fileno(), selectors.EVENT_READ),
        ]

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_modify_post_data_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_modify_post_data_plugin(self) -> None:
        """The POST body and its headers are rewritten before forwarding."""
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'

        self._conn.recv.return_value = build_http_request(
            b'POST',
            b'http://httpbin.org/post',
            headers={
                b'Host': b'httpbin.org',
                b'Content-Type': b'application/x-www-form-urlencoded',
                b'Content-Length': bytes_(len(original)),
            },
            body=original,
            no_ua=True,
        )
        self._client_readable_once()

        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with(
            'httpbin.org',
            DEFAULT_HTTP_PORT,
        )
        self.mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(
                b'POST',
                b'/post',
                headers={
                    b'Host': b'httpbin.org',
                    b'Content-Type': b'application/json',
                    b'Content-Length': bytes_(len(modified)),
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                body=modified,
                no_ua=True,
            ), )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_proposed_rest_api_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_proposed_rest_api_plugin(self) -> None:
        """API requests are answered locally with gzipped JSON, no upstream."""
        path = b'/v1/users/'
        self._conn.recv.return_value = build_http_request(
            b'GET',
            b'http://%s%s' % (
                ProposedRestApiPlugin.API_SERVER,
                path,
            ),
            headers={
                b'Host': ProposedRestApiPlugin.API_SERVER,
            },
        )
        self._client_readable_once()
        await self.protocol_handler._run_once()

        self.mock_server_conn.assert_not_called()
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        assert response.body
        self.assertEqual(
            response.header(b'content-type'),
            b'application/json',
        )
        self.assertEqual(
            response.header(b'content-encoding'),
            b'gzip',
        )
        self.assertEqual(
            gzip.decompress(response.body),
            bytes_(json.dumps(ProposedRestApiPlugin.REST_API_SPEC[path], ), ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_redirect_to_custom_server_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_redirect_to_custom_server_plugin(self) -> None:
        """Plain HTTP requests are re-routed to ``UPSTREAM_SERVER``."""
        request = build_http_request(
            b'GET',
            b'http://example.org/get',
            headers={
                b'Host': b'example.org',
            },
            no_ua=True,
        )
        self._conn.recv.return_value = request
        self._client_readable_once()
        await self.protocol_handler._run_once()

        upstream = urlparse.urlsplit(
            RedirectToCustomServerPlugin.UPSTREAM_SERVER, )
        self.mock_server_conn.assert_called_with('localhost', 8899)
        self.mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(
                b'GET',
                upstream.path,
                headers={
                    b'Host': upstream.netloc,
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                no_ua=True,
            ), )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_redirect_to_custom_server_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_redirect_to_custom_server_plugin_skips_https(self) -> None:
        """CONNECT requests are tunnelled to the original host unchanged."""
        request = build_http_request(
            b'CONNECT',
            b'jaxl.com:443',
            headers={
                b'Host': b'jaxl.com:443',
            },
        )
        self._conn.recv.return_value = request
        self._client_readable_once()
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with('jaxl.com', 443)
        self.assertEqual(
            self.protocol_handler.work.buffer[0].tobytes(),
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_filter_by_upstream_host_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_filter_by_upstream_host_plugin(self) -> None:
        """A blocked host gets a 418 and no upstream connection is made."""
        request = build_http_request(
            b'GET',
            b'http://facebook.com/',
            headers={
                b'Host': b'facebook.com',
            },
        )
        self._conn.recv.return_value = request
        self._client_readable_once()
        await self.protocol_handler._run_once()

        self.mock_server_conn.assert_not_called()
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            build_http_response(
                status_code=httpStatusCodes.I_AM_A_TEAPOT,
                reason=b'I\'m a tea pot',
                conn_close=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_man_in_the_middle_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_man_in_the_middle_plugin(self) -> None:
        """The upstream response body is replaced before reaching the client.

        Drives three loop iterations: client read, upstream write, upstream
        read.
        """
        request = build_http_request(
            b'GET',
            b'http://super.secure/',
            headers={
                b'Host': b'super.secure',
            },
            no_ua=True,
        )
        self._conn.recv.return_value = request

        server = self.mock_server_conn.return_value
        server.connect.return_value = True

        # The mocked upstream reports pending data once something was queued
        # and reports itself closed until connect() has been called.
        def has_buffer() -> bool:
            return cast(bool, server.queue.called)

        def closed() -> bool:
            return not server.connect.called

        server.has_buffer.side_effect = has_buffer
        type(server).closed = mock.PropertyMock(side_effect=closed)

        self.mock_selector.return_value.select.side_effect = [
            self._ready(self._conn.fileno(), selectors.EVENT_READ),
            self._ready(server.connection.fileno(), selectors.EVENT_WRITE),
            self._ready(server.connection.fileno(), selectors.EVENT_READ),
        ]

        # Client read
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with(
            'super.secure',
            DEFAULT_HTTP_PORT,
        )
        server.connect.assert_called_once()
        queued_request = \
            build_http_request(
                b'GET', b'/',
                headers={
                    b'Host': b'super.secure',
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                no_ua=True,
            )
        server.queue.assert_called_once()
        self.assertEqual(server.queue.call_args_list[0][0][0], queued_request)

        # Server write
        await self.protocol_handler._run_once()
        server.flush.assert_called_once()

        # Server read
        server.recv.return_value = \
            build_http_response(
                httpStatusCodes.OK,
                reason=b'OK',
                body=b'Original Response From Upstream',
            )
        await self.protocol_handler._run_once()
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        assert response.body
        self.assertEqual(
            gzip.decompress(response.body),
            b'Hello from man in the middle',
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_filter_by_url_regex_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_filter_by_url_regex_plugin(self) -> None:
        """A URL matching the configured regex rules gets 404 ``Blocked``."""
        request = build_http_request(
            b'GET',
            b'http://www.facebook.com/tr/',
            headers={
                b'Host': b'www.facebook.com',
            },
        )
        self._conn.recv.return_value = request
        self._client_readable_once()
        await self.protocol_handler._run_once()

        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            build_http_response(
                status_code=httpStatusCodes.NOT_FOUND,
                reason=b'Blocked',
                conn_close=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_shortlink_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_shortlink_plugin(self) -> None:
        """A known shortlink host answers with a 303 redirect."""
        request = build_http_request(
            b'GET',
            b'http://t/',
            headers={
                b'Host': b't',
            },
        )
        self._conn.recv.return_value = request
        self._client_readable_once()

        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.work.buffer[0].tobytes(),
            build_http_response(
                status_code=httpStatusCodes.SEE_OTHER,
                reason=b'See Other',
                headers={
                    b'Location': b'http://twitter.com/',
                },
                conn_close=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_shortlink_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_shortlink_plugin_unknown(self) -> None:
        """An unknown single-label host gets a 404."""
        request = build_http_request(
            b'GET',
            b'http://unknown/',
            headers={
                b'Host': b'unknown',
            },
        )
        self._conn.recv.return_value = request
        self._client_readable_once()
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.work.buffer[0].tobytes(),
            NOT_FOUND_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        (('test_shortlink_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_shortlink_plugin_external(self) -> None:
        """Regular hosts are proxied upstream and nothing is sent back yet."""
        request = build_http_request(
            b'GET',
            b'http://jaxl.com/',
            headers={
                b'Host': b'jaxl.com',
            },
            no_ua=True,
        )
        self._conn.recv.return_value = request
        self._client_readable_once()
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_once_with('jaxl.com', 80)
        self.mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(
                b'GET',
                b'/',
                headers={
                    b'Host': b'jaxl.com',
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                no_ua=True,
            ), )
        self.assertFalse(self.protocol_handler.work.has_buffer())


class TestHttpProxyAuthFailed(Assertions):
    """Proxy basic-auth behaviour with ``--basic-auth user:pass`` configured.

    Each test feeds a single proxied ``GET`` through the handler and checks
    whether an upstream connection was attempted and whether the client was
    answered with ``PROXY_AUTH_FAILED_RESPONSE_PKT``.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(
            ["--basic-auth", "user:pass"],
            threaded=True,
        )
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    def _mock_client_request(self, proxy_authorization: bytes = b'') -> None:
        """Make the client send one proxied GET and be readable exactly once.

        :param proxy_authorization: value for the ``Proxy-Authorization``
            header; the header is omitted entirely when this is empty.
        """
        headers = {b'Host': b'upstream.host'}
        if proxy_authorization:
            headers[httpHeaders.PROXY_AUTHORIZATION] = proxy_authorization
        self._conn.recv.return_value = build_http_request(
            b'GET',
            b'http://upstream.host/not-found.html',
            headers=headers,
        )
        # A single scripted select() result: the client socket is readable.
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

    def _assert_rejected(self) -> None:
        """No upstream connection; the auth-failed reply is queued, not yet sent."""
        self.mock_server_conn.assert_not_called()
        self.assertTrue(self.protocol_handler.work.has_buffer())
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            PROXY_AUTH_FAILED_RESPONSE_PKT,
        )
        self._conn.send.assert_not_called()

    def _assert_accepted(self) -> None:
        """One upstream connection was created and nothing is queued for the client."""
        self.mock_server_conn.assert_called_once()
        self.assertFalse(self.protocol_handler.work.has_buffer())

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_fails_without_cred(self) -> None:
        self._mock_client_request()
        await self.protocol_handler._run_once()
        self._assert_rejected()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_fails_with_invalid_cred(self) -> None:
        self._mock_client_request(b'Basic hello')
        await self.protocol_handler._run_once()
        self._assert_rejected()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_works_with_valid_cred(self) -> None:
        # dXNlcjpwYXNz is base64("user:pass"), matching --basic-auth above.
        self._mock_client_request(b'Basic dXNlcjpwYXNz')
        await self.protocol_handler._run_once()
        self._assert_accepted()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_works_with_mixed_case_basic_string(self) -> None:
        # The auth scheme name must be matched case-insensitively.
        self._mock_client_request(b'bAsIc dXNlcjpwYXNz')
        await self.protocol_handler._run_once()
        self._assert_accepted()
示例#16
0
class TestStaticWebServerPlugin(Assertions):
    """Static file serving through the web server plugin (``enable_static_server``)."""

    @pytest.fixture(autouse=True)   # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture, tmp_path: Path) -> None:
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value
        # Serve from a per-test directory under pytest's ``tmp_path`` rather
        # than a fixed ``<system tmp>/static`` path. A fixed path is shared by
        # every process and user on the host (collisions under parallel runs,
        # permission errors on a directory owned by someone else, stale files),
        # whereas pytest creates and manages ``tmp_path`` for each test.
        static_dir = tmp_path / 'static'
        static_dir.mkdir()
        index_file = static_dir / 'index.html'
        self.static_server_dir = str(static_dir)
        self.index_file_path = str(index_file)
        self.html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''
        index_file.write_bytes(self.html_file_content)

        flags = FlagParser.initialize(
            enable_static_server=True,
            static_server_dir=self.static_server_dir,
            threaded=True,
        )
        flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler.initialize()

    def _mock_client_read_then_write(self) -> None:
        """Script two select() results: client readable, then client writable."""
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_WRITE,
                    data=None,
                ),
                selectors.EVENT_WRITE,
            )],
        ]

    @pytest.mark.asyncio    # type: ignore[misc]
    async def test_static_web_server_serves(self) -> None:
        """An existing file is served gzip-encoded with caching headers."""
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/index.html',
        )
        self._mock_client_read_then_write()

        # Tick 1 reads the request; tick 2 writes the response.
        await self.protocol_handler._run_once()
        await self.protocol_handler._run_once()

        self.assertEqual(self.mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        encoded_html_file_content = gzip.compress(self.html_file_content)

        # parse response and verify
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self._conn.send.call_args[0][0])
        self.assertEqual(response.code, b'200')
        self.assertEqual(response.header(b'content-type'), b'text/html')
        self.assertEqual(response.header(b'cache-control'), b'max-age=86400')
        self.assertEqual(response.header(b'content-encoding'), b'gzip')
        self.assertEqual(response.header(b'connection'), b'close')
        self.assertEqual(
            response.header(b'content-length'),
            bytes_(len(encoded_html_file_content)),
        )
        assert response.body
        self.assertEqual(
            gzip.decompress(response.body),
            self.html_file_content,
        )

    @pytest.mark.asyncio    # type: ignore[misc]
    async def test_static_web_server_serves_404(self) -> None:
        """A missing file yields the canned 404 response."""
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/not-found.html',
        )
        self._mock_client_read_then_write()

        await self.protocol_handler._run_once()
        await self.protocol_handler._run_once()

        self.assertEqual(self.mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        self.assertEqual(
            self._conn.send.call_args[0][0],
            NOT_FOUND_RESPONSE_PKT,
        )
示例#17
0
class TestHttpProtocolHandlerWithoutServerMock(Assertions):
    """Handler error replies with ``TcpServerConnection`` left unpatched.

    Only ``socket.fromfd`` and ``selectors.DefaultSelector`` are mocked, so any
    upstream connection the proxy plugin attempts is a real one
    (``test_proxy_connection_failed`` relies on ``unknown.domain`` failing).
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value

        self.http_server_port = 65535
        self.flags = FlagParser.initialize(threaded=True)
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])
        self._install_handler(self.flags)

    def _install_handler(self, flags: Any) -> None:
        """(Re)build ``self.protocol_handler`` around the mocked client socket."""
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler.initialize()

    async def _first_reply_to(self, *request_lines: bytes) -> Any:
        """Deliver ``request_lines`` CRLF-joined and terminated by an empty line,
        run one loop tick, and return the first chunk queued for the client."""
        self._conn.recv.return_value = CRLF.join([*request_lines, CRLF])
        await self.protocol_handler._run_once()
        return self.protocol_handler.work.buffer[0]

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_connection_failed(self) -> None:
        mock_selector_for_client_read(self)
        reply = await self._first_reply_to(
            b'GET http://unknown.domain HTTP/1.1',
            b'Host: unknown.domain',
        )
        self.assertEqual(reply, BAD_GATEWAY_RESPONSE_PKT)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_authentication_failed(self) -> None:
        mock_selector_for_client_read(self)
        auth_flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        auth_flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PROXY_AUTH),
        ])
        # Replace the fixture's handler with one that enforces proxy auth.
        self._install_handler(auth_flags)
        reply = await self._first_reply_to(
            b'GET http://abhinavsingh.com HTTP/1.1',
            b'Host: abhinavsingh.com',
        )
        self.assertEqual(reply, PROXY_AUTH_FAILED_RESPONSE_PKT)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_bails_out_for_unknown_schemes(self) -> None:
        mock_selector_for_client_read(self)
        reply = await self._first_reply_to(
            b'REQMOD icap://icap-server.net/server?arg=87 ICAP/1.0',
            b'Host: icap-server.net',
        )
        self.assertEqual(reply, BAD_REQUEST_RESPONSE_PKT)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_bails_out_for_sip_request_lines(self) -> None:
        mock_selector_for_client_read(self)
        reply = await self._first_reply_to(
            b'OPTIONS sip:nm SIP/2.0',
            b'Accept: application/sdp',
        )
        self.assertEqual(reply, BAD_REQUEST_RESPONSE_PKT)
示例#18
0
class TestHttpProtocolHandler(Assertions):
    """HttpProtocolHandler driven through mocked sockets and a scripted selector.

    ``socket.fromfd``, ``selectors.DefaultSelector`` and the proxy's
    ``TcpServerConnection`` are patched. Each ``_run_once()`` call is expected
    to consume the next scripted ``select()`` result, and upstream traffic is
    only observed through the ``server`` mock (the patched class's
    ``return_value``).
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        """Patch I/O and build a handler with the HTTP proxy and web server plugins."""
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_connection = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value

        # Port only used to build request lines/Host headers; nothing listens on it.
        self.http_server_port = 65535
        self.flags = FlagParser.initialize(threaded=True)
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])

        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_http_get(self) -> None:
        """A proxied GET arriving in two reads is queued upstream, then flushed."""
        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        # select() #1, #2: client readable; #3: upstream writable.
        self.mock_selector_for_client_read_and_server_write(server)

        # Send request line
        assert self.http_server_port is not None
        self._conn.recv.return_value = (b'GET http://localhost:%d HTTP/1.1' %
                                        self.http_server_port) + CRLF

        await self.protocol_handler._run_once()

        # Only the request line has been parsed so far.
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.LINE_RCVD,
        )
        self.assertNotEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )

        # Send headers and blank line, thus completing HTTP request
        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            CRLF,
        ])
        await self.assert_data_queued(server)
        # Third tick: upstream is writable, so its queued data is flushed.
        await self.protocol_handler._run_once()
        server.flush.assert_called_once()

    async def assert_tunnel_response(
        self,
        server: mock.Mock,
    ) -> None:
        """Run one tick on a pending CONNECT and check the tunnel was set up.

        Expects: the proxy plugin has an upstream, the tunnel-established
        response is queued for the client (and parses as a complete 200),
        exactly one upstream connection was created and connected, and
        nothing has been queued upstream yet.
        """
        await self.protocol_handler._run_once()
        self.assertTrue(
            cast(
                HttpProxyPlugin,
                self.protocol_handler.plugin,
            ).upstream is not None, )
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )
        self.mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.queue.assert_not_called()
        # Presumably so later ticks treat the mocked upstream as open.
        server.closed = False

        parser = HttpParser(httpParserTypes.RESPONSE_PARSER)
        parser.parse(self.protocol_handler.work.buffer[0])
        self.assertEqual(parser.state, httpParserStates.COMPLETE)
        assert parser.code is not None
        self.assertEqual(int(parser.code), 200)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_http_tunnel(self) -> None:
        """CONNECT opens a tunnel; tunneled bytes are queued then flushed upstream."""
        server = self.mock_server_connection.return_value
        server.connect.return_value = True

        # The upstream mock reports pending data only once something was queued.
        def has_buffer() -> bool:
            return cast(bool, server.queue.called)

        server.has_buffer.side_effect = has_buffer
        # Scripted select() results, one per _run_once():
        #   #1 client readable  -> CONNECT parsed, tunnel response queued
        #   #2 client writable  -> tunnel response sent to client
        #   #3 client readable  -> tunneled GET read and queued upstream
        #   #4 upstream writable -> queued bytes flushed upstream
        self.mock_selector.return_value.select.side_effect = [
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=selectors.EVENT_READ,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_WRITE,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=selectors.EVENT_READ,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=server.connection.fileno(),
                        fd=server.connection.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_WRITE,
                ),
            ],
        ]

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Proxy-Connection: Keep-Alive',
            CRLF,
        ])
        await self.assert_tunnel_response(server)

        # Dispatch tunnel established response to client
        await self.protocol_handler._run_once()
        await self.assert_data_queued_to_server(server)

        await self.protocol_handler._run_once()
        self.assertEqual(server.queue.call_count, 1)
        server.flush.assert_called_once()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_authenticated_proxy_http_get(self) -> None:
        """A GET carrying valid Proxy-Authorization is forwarded upstream.

        NOTE(review): ``auth_code`` is set but ``PLUGIN_PROXY_AUTH`` is not
        loaded here (unlike ``test_proxy_authentication_failed``); confirm
        whether credentials are actually checked in this test.
        """
        self._conn = self.mock_fromfd.return_value
        # Helper defined elsewhere in this file; presumably makes every
        # select() report the client socket as readable — TODO confirm.
        mock_selector_for_client_read(self)

        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])

        # Replace the fixture's handler with one built from the auth flags.
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler.initialize()
        assert self.http_server_port is not None

        # Request line without its trailing CRLF: not yet a complete line.
        self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.INITIALIZED,
        )

        # The CRLF completes the request line.
        self._conn.recv.return_value = CRLF
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.LINE_RCVD,
        )

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
            httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
            CRLF,
        ])
        await self.assert_data_queued(server)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_authenticated_proxy_http_tunnel(self) -> None:
        """CONNECT with valid Proxy-Authorization opens a working tunnel.

        NOTE(review): as in ``test_authenticated_proxy_http_get``,
        ``PLUGIN_PROXY_AUTH`` is not loaded; confirm auth is exercised.
        """
        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0
        self._conn = self.mock_fromfd.return_value
        self.mock_selector_for_client_read_and_server_write(server)

        flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])

        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler.initialize()

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
            httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
            CRLF,
        ])
        await self.assert_tunnel_response(server)
        # Flush the tunnel response to the client directly instead of
        # scripting a client-writable select() tick.
        self.protocol_handler.work.flush()
        await self.assert_data_queued_to_server(server)

        await self.protocol_handler._run_once()
        server.flush.assert_called_once()

    def mock_selector_for_client_read_and_server_write(
        self,
        server: mock.Mock,
    ) -> None:
        """Script three select() results for the handler's event loop.

        #1 and #2 report the client socket ready for reading (#2 uses a key
        registered with ``events=0`` but a ready mask of ``EVENT_READ``);
        #3 reports the upstream ``server.connection`` ready for writing.
        """
        self.mock_selector.return_value.select.side_effect = [
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=selectors.EVENT_READ,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=server.connection.fileno(),
                        fd=server.connection.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_WRITE,
                ),
            ],
        ]

    async def assert_data_queued(
        self,
        server: mock.Mock,
    ) -> None:
        """Run the tick that completes the request and check it went upstream.

        The request queued upstream is the origin-form ``GET /`` with the
        client's ``Proxy-*`` headers removed and a ``Via`` header appended.
        """
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )
        self.mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False
        assert self.http_server_port is not None
        pkt = CRLF.join([
            b'GET / HTTP/1.1',
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Via: 1.1 proxy.py v%s' % bytes_(__version__),
            CRLF,
        ])
        server.queue.assert_called_once()
        self.assertEqual(server.queue.call_args_list[0][0][0], pkt)
        # Presumably makes the next tick see pending upstream data to flush.
        server.buffer_size.return_value = len(pkt)

    async def assert_data_queued_to_server(self, server: mock.Mock) -> None:
        """Check the tunnel reply reached the client, then tunnel a GET upstream.

        The tunneled request must be queued upstream byte-for-byte and must
        not be flushed within this tick.
        """
        assert self.http_server_port is not None
        self.assertEqual(
            self._conn.send.call_args[0][0],
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )

        pkt = CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            CRLF,
        ])

        self._conn.recv.return_value = pkt
        await self.protocol_handler._run_once()

        server.queue.assert_called_once_with(pkt)
        server.buffer_size.return_value = len(pkt)
        server.flush.assert_not_called()
示例#19
0
    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_e2e(self, mocker: MockerFixture) -> None:
        """TLS interception of a CONNECT tunnel, end to end with mocks.

        A ``CONNECT <random-host>:443`` request is fed through the handler
        with CA flags configured. The test asserts that:

        - the proxy plugin is instantiated only once the request completes;
        - the upstream connection is TLS-wrapped with the default CA file;
        - a certificate is generated for the host (public key, CSR, signature);
        - the tunnel-established response is sent to the client;
        - the client socket is then wrapped server-side with the generated cert.
        """
        host, port = uuid.uuid4().hex, 443
        netloc = f'{host}:{port}'

        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
        self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
        self.mock_gen_public_key = mocker.patch(
            'proxy.http.proxy.server.gen_public_key', )
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )
        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')

        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection
        self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket)
        plain_connection = mock.MagicMock(spec=socket.socket)

        def mock_connection() -> Any:
            # Upstream ``connection`` is the plain socket until the SSL
            # context has wrapped it, and the TLS socket afterwards.
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return ssl_connection
            return plain_connection

        # Do not mock the original wrap method
        self.mock_server_conn.return_value.wrap.side_effect = \
            lambda x, y: TcpServerConnection.wrap(
                self.mock_server_conn.return_value, x, y,
            )

        type(self.mock_server_conn.return_value).connection = \
            mock.PropertyMock(side_effect=mock_connection)

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
            threaded=True,
        )
        self.plugin = mock.MagicMock()
        self.proxy_plugin = mock.MagicMock()
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.proxy_plugin],
        }
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

        # No plugin is instantiated before a request arrives.
        self.plugin.assert_not_called()
        self.proxy_plugin.assert_not_called()

        connect_request = build_http_request(
            httpMethods.CONNECT,
            bytes_(netloc),
            headers={
                b'Host': bytes_(netloc),
            },
        )
        self._conn.recv.return_value = connect_request

        async def _async_return(val: bool) -> bool:
            return val

        # Prepare mocked HttpProxyBasePlugin. The descriptor hooks produce a
        # fresh coroutine per call: a single pre-built coroutine stored as
        # ``return_value`` can be awaited only once (a second call raises
        # RuntimeError) and emits "coroutine ... was never awaited" if unused.
        self.proxy_plugin.return_value.write_to_descriptors.side_effect = \
            lambda *_args, **_kwargs: _async_return(False)
        self.proxy_plugin.return_value.read_from_descriptors.side_effect = \
            lambda *_args, **_kwargs: _async_return(False)
        self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r
        self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r
        self.proxy_plugin.return_value.resolve_dns.return_value = None, None

        # One select() result: the client socket is readable.
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

        await self.protocol_handler._run_once()

        # Assert correct plugin was initialized
        self.plugin.assert_not_called()
        self.proxy_plugin.assert_called_once()
        self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags)
        # Actual call arg must be `_conn` object
        # but because internally the reference is updated
        # we assert it against `mock_ssl_wrap` which is
        # called during proxy plugin initialization
        # for interception
        self.assertEqual(
            self.proxy_plugin.call_args[0][2].connection,
            self.mock_ssl_wrap.return_value,
        )

        # The proxy plugin's request hooks ran.
        self.proxy_plugin.return_value.before_upstream_connection.assert_called()
        self.proxy_plugin.return_value.handle_client_request.assert_called()

        # Upstream: connected, made non-blocking, TLS-wrapped with the default CA file.
        self.mock_server_conn.assert_called_with(host, port)
        self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
            False, )

        self.mock_ssl_context.assert_called_with(
            ssl.Purpose.SERVER_AUTH,
            cafile=str(DEFAULT_CA_FILE),
        )
        self.assertEqual(plain_connection.setblocking.call_count, 2)
        self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
            plain_connection,
            server_hostname=host,
        )
        # A certificate for ``host`` was generated exactly once.
        self.assertEqual(self.mock_sign_csr.call_count, 1)
        self.assertEqual(self.mock_gen_csr.call_count, 1)
        self.assertEqual(self.mock_gen_public_key.call_count, 1)
        self.assertEqual(ssl_connection.setblocking.call_count, 1)
        self.assertEqual(
            self.mock_server_conn.return_value._conn,
            ssl_connection,
        )

        # Client: tunnel established, then wrapped server-side with the generated cert.
        self._conn.send.assert_called_with(
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, )
        assert self.flags.ca_cert_dir is not None
        self.mock_ssl_wrap.assert_called_with(
            self._conn,
            server_side=True,
            keyfile=self.flags.ca_signing_key_file,
            certfile=HttpProxyPlugin.generated_cert_file_path(
                self.flags.ca_cert_dir,
                host,
            ),
            ssl_version=ssl.PROTOCOL_TLS,
        )
        self.assertEqual(self._conn.setblocking.call_count, 2)
        self.assertEqual(
            self.protocol_handler.work.connection,
            self.mock_ssl_wrap.return_value,
        )

        # The proxy plugin's client connection reference was updated too.
        self.assertEqual(
            self.proxy_plugin.return_value.client._conn,
            self.mock_ssl_wrap.return_value,
        )
示例#20
0
class TestHttpProxyPlugin:
    """Lifecycle tests for ``HttpProxyBasePlugin`` hooks as driven by
    ``HttpProtocolHandler`` with ``HttpProxyPlugin`` enabled.

    Upstream connections, the selector and the client socket are all mocked;
    the only proxy plugin installed is a ``MagicMock`` (``self.plugin``).
    """

    @pytest.fixture(autouse=True)   # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        """Patch I/O primitives and build an initialized protocol handler."""
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_fromfd = mocker.patch('socket.fromfd')

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(threaded=True)
        self.plugin = mocker.MagicMock()
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.plugin],
        }
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    def _mock_client_get_request(self) -> None:
        """Make the client socket readable exactly once, carrying a plain
        HTTP ``GET http://upstream.host/not-found.html`` request."""
        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://upstream.host/not-found.html',
            headers={
                b'Host': b'upstream.host',
            },
        )
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

    def test_proxy_plugin_not_initialized_unless_first_request_completes(self) -> None:
        self.plugin.assert_not_called()

    @pytest.mark.asyncio    # type: ignore[misc]
    async def test_proxy_plugin_on_and_before_upstream_connection(self) -> None:
        self.plugin.return_value.write_to_descriptors.return_value = False
        self.plugin.return_value.read_from_descriptors.return_value = False
        # Pass-through hooks: return the request unchanged.
        self.plugin.return_value.before_upstream_connection.side_effect = lambda r: r
        self.plugin.return_value.handle_client_request.side_effect = lambda r: r
        self.plugin.return_value.resolve_dns.return_value = None, None
        self._mock_client_get_request()

        await self.protocol_handler._run_once()

        self.mock_server_conn.assert_called_with(
            'upstream.host', DEFAULT_HTTP_PORT,
        )
        self.plugin.return_value.before_upstream_connection.assert_called()
        self.plugin.return_value.handle_client_request.assert_called()

    @pytest.mark.asyncio    # type: ignore[misc]
    async def test_proxy_plugin_before_upstream_connection_can_teardown(self) -> None:
        self.plugin.return_value.write_to_descriptors.return_value = False
        self.plugin.return_value.read_from_descriptors.return_value = False
        # Raising here must abort before any upstream connection is created.
        self.plugin.return_value.before_upstream_connection.side_effect = HttpProtocolException()
        self._mock_client_get_request()

        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_not_called()
        self.plugin.return_value.before_upstream_connection.assert_called()

    # The tests below are placeholders.  They are explicitly skipped so the
    # report does not claim coverage that does not exist (a bare ``pass``
    # body would count as a passing test).

    def test_proxy_plugin_plugins_can_teardown_from_write_to_descriptors(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_retries_on_ssl_want_write_error(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_broken_pipe_error_on_write_will_teardown(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_plugins_can_teardown_from_read_from_descriptors(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_retries_on_ssl_want_read_error(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_timeout_error_on_read_will_teardown(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_invokes_handle_pipeline_response(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_invokes_on_access_log(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_skips_server_teardown_when_client_closes_and_server_never_initialized(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_invokes_handle_client_data(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_handles_pipeline_response(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_invokes_resolve_dns(self) -> None:
        pytest.skip('TODO: test not implemented yet')

    def test_proxy_plugin_require_both_host_port_to_connect(self) -> None:
        pytest.skip('TODO: test not implemented yet')
class TestHttpProxyPluginExamplesWithTlsInterception(Assertions):
    """Run bundled example proxy plugins against a mocked, TLS-intercepted
    ``CONNECT uni.corn:443`` session.

    ``_setUp`` is parametrized indirectly: ``request.param`` is the test name
    used to look up the example plugin via ``get_plugin_by_test_name``.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Patch socket/selector/TLS/certificate helpers, install the example
        plugin under test, and script the client CONNECT plus the selector
        events consumed by successive ``_run_once`` calls."""
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
        self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
        self.mock_gen_public_key = mocker.patch(
            'proxy.http.proxy.server.gen_public_key', )
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )
        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')

        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
            threaded=True,
        )
        self.plugin = mocker.MagicMock()

        plugin = get_plugin_by_test_name(request.param)

        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = mocker.MagicMock(spec=socket.socket)
        self.mock_fromfd.return_value = self._conn
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

        self.server = self.mock_server_conn.return_value

        self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection
        self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = self.client_ssl_connection

        def has_buffer() -> bool:
            # Upstream is considered to have pending data once anything was queued.
            return cast(bool, self.server.queue.called)

        def closed() -> bool:
            # Upstream counts as closed until connect() has been called.
            return not self.server.connect.called

        def mock_connection() -> Any:
            # Plain socket until upstream has been TLS-wrapped.
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return self.server_ssl_connection
            return self._conn

        # Do not mock the original wrap method
        self.server.wrap.side_effect = \
            lambda x, y, as_non_blocking: TcpServerConnection.wrap(
                self.server, x, y, as_non_blocking=as_non_blocking,
            )

        self.server.has_buffer.side_effect = has_buffer
        type(self.server).closed = mocker.PropertyMock(side_effect=closed)
        type(self.server).connection = mocker.PropertyMock(
            side_effect=mock_connection, )

        # One entry per ``_run_once`` call, in the order the tests drive them.
        self.mock_selector.return_value.select.side_effect = [
            # CONNECT request readable on the plain client socket
            self._select_result(self._conn, selectors.EVENT_READ),
            # Request readable on the intercepted client TLS socket
            self._select_result(self.client_ssl_connection, selectors.EVENT_READ),
            # Upstream TLS socket writable
            self._select_result(self.server_ssl_connection, selectors.EVENT_WRITE),
            # Upstream TLS socket readable
            self._select_result(self.server_ssl_connection, selectors.EVENT_READ),
        ]

        # Connect
        def send(raw: bytes) -> int:
            return len(raw)

        self._conn.send.side_effect = send
        self._conn.recv.return_value = build_http_request(
            httpMethods.CONNECT,
            b'uni.corn:443',
            no_ua=True,
        )

    @staticmethod
    def _select_result(sock: Any, event: int) -> Any:
        """Build one ``selector.select()`` result reporting ``event`` on ``sock``."""
        return [(
            selectors.SelectorKey(
                fileobj=sock.fileno(),
                fd=sock.fileno(),
                events=event,
                data=None,
            ),
            event,
        )]

    def _assert_tls_tunnel_established(self) -> None:
        """Assert the CONNECT to ``uni.corn:443`` was intercepted.

        Checks that the CSR/key helpers each ran once, that upstream was
        connected, and that both client and upstream connections are the TLS
        mocks. It also checks that the client received the tunnel-established
        response and that nothing is buffered for the client yet.
        """
        self.assertEqual(self.mock_sign_csr.call_count, 1)
        self.assertEqual(self.mock_gen_csr.call_count, 1)
        self.assertEqual(self.mock_gen_public_key.call_count, 1)

        self.mock_server_conn.assert_called_once_with('uni.corn', 443)
        self.server.connect.assert_called()
        self.assertEqual(
            self.protocol_handler.work.connection,
            self.client_ssl_connection,
        )
        self.assertEqual(self.server.connection, self.server_ssl_connection)
        self._conn.send.assert_called_with(
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, )
        self.assertFalse(self.protocol_handler.work.has_buffer())

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        '_setUp',
        (('test_modify_post_data_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_modify_post_data_plugin(self) -> None:
        await self.protocol_handler._run_once()
        self._assert_tls_tunnel_established()

        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'
        self.client_ssl_connection.recv.return_value = build_http_request(
            b'POST',
            b'/',
            headers={
                b'Host': b'uni.corn',
                b'Content-Length': bytes_(len(original)),
                b'Content-Type': b'application/x-www-form-urlencoded',
            },
            body=original,
            no_ua=True,
        )
        await self.protocol_handler._run_once()
        self.server.queue.assert_called_once()
        # What is queued for upstream is the (plugin-modified) client
        # *request*, so it must be parsed with a request parser.
        upstream_request = HttpParser.request(
            self.server.queue.call_args_list[0][0][0].tobytes(), )
        self.assertEqual(upstream_request.body, modified)

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        '_setUp',
        (('test_man_in_the_middle_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_man_in_the_middle_plugin(self) -> None:
        await self.protocol_handler._run_once()
        self._assert_tls_tunnel_established()

        request = build_http_request(
            b'GET',
            b'/',
            headers={
                b'Host': b'uni.corn',
            },
            no_ua=True,
        )
        self.client_ssl_connection.recv.return_value = request

        # Client read
        await self.protocol_handler._run_once()
        self.server.queue.assert_called_once_with(request)

        # Server write
        await self.protocol_handler._run_once()
        self.server.flush.assert_called_once()

        # Server read
        self.server.recv.return_value = okResponse(
            content=b'Original Response From Upstream', )
        await self.protocol_handler._run_once()
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        assert response.body
        self.assertEqual(
            gzip.decompress(response.body),
            b'Hello from man in the middle',
        )
示例#22
0
    async def test_e2e(self, mocker: MockerFixture) -> None:
        """End-to-end TLS interception through ``HttpProtocolHandler``.

        A CONNECT to a random ``<uuid hex>:443`` host is read from the plain
        client socket, then a GET is read from the intercepted (TLS-wrapped)
        client socket.  After each ``_run_once`` the test asserts which
        ``HttpProxyBasePlugin`` lifecycle callbacks were invoked, with which
        parsed request, and how the upstream and client sockets were wrapped.
        """
        # NOTE(review): no ``@pytest.mark.asyncio`` decorator is visible
        # directly above this coroutine test; confirm pytest-asyncio collects
        # it (e.g. via a mark elsewhere or ``asyncio_mode = auto``).
        host, port = uuid.uuid4().hex, 443
        netloc = '{0}:{1}'.format(host, port)

        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
        self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
        self.mock_gen_public_key = mocker.patch(
            'proxy.http.proxy.server.gen_public_key', )
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )
        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        # Used for server side wrapping
        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
        upstream_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock

        # Used for client wrapping
        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')
        client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = client_tls_sock

        plain_connection = mock.MagicMock(spec=socket.socket)

        # Upstream ``connection`` is the plain socket until the SSL context's
        # ``wrap_socket`` has been called, then the upstream TLS socket.
        def mock_connection() -> Any:
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return upstream_tls_sock
            return plain_connection

        # Do not mock the original wrap method
        self.mock_server_conn.return_value.wrap.side_effect = \
            lambda x, y, as_non_blocking: TcpServerConnection.wrap(
                self.mock_server_conn.return_value, x, y, as_non_blocking=as_non_blocking,
            )

        type(self.mock_server_conn.return_value).connection = \
            mock.PropertyMock(side_effect=mock_connection)

        type(self.mock_server_conn.return_value).closed = \
            mock.PropertyMock(return_value=False)

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
            threaded=True,
        )
        self.assertTrue(tls_interception_enabled(self.flags))
        # In this test we enable a mock http protocol handler plugin
        # and a mock http proxy plugin.  Internally, http protocol
        # handler will only initialize proxy plugin as we'll never
        # make any other request.
        self.plugin = mock.MagicMock()
        self.proxy_plugin = mock.MagicMock()
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.proxy_plugin],
        }
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

        self.plugin.assert_not_called()
        self.proxy_plugin.assert_not_called()

        # Mock a CONNECT request followed by a GET request
        # from client connection
        headers = {
            b'Host': bytes_(netloc),
        }
        connect_request = build_http_request(
            httpMethods.CONNECT,
            bytes_(netloc),
            headers=headers,
            no_ua=True,
        )
        self._conn.recv.return_value = connect_request
        get_request = build_http_request(
            httpMethods.GET,
            b'/',
            headers=headers,
        )
        client_tls_sock.recv.return_value = get_request

        T = TypeVar('T')  # noqa: N806

        async def asyncReturn(val: T) -> T:
            return val

        # Prepare mocked HttpProxyBasePlugin
        # 1. Mock descriptor mixin methods
        #
        # NOTE: We need multiple async result otherwise
        # we will end up with cannot await on already
        # awaited coroutine.
        self.proxy_plugin.return_value.get_descriptors.side_effect = \
            [asyncReturn(([], [])), asyncReturn(([], []))]
        self.proxy_plugin.return_value.write_to_descriptors.side_effect = \
            [asyncReturn(False), asyncReturn(False)]
        self.proxy_plugin.return_value.read_from_descriptors.side_effect = \
            [asyncReturn(False), asyncReturn(False)]
        # 2. Mock plugin lifecycle methods
        self.proxy_plugin.return_value.resolve_dns.return_value = None, None
        self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r
        self.proxy_plugin.return_value.handle_client_data.side_effect = lambda r: r
        self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r
        self.proxy_plugin.return_value.handle_upstream_chunk.side_effect = lambda r: r
        self.proxy_plugin.return_value.on_upstream_connection_close.return_value = None
        self.proxy_plugin.return_value.on_access_log.side_effect = lambda r: r
        self.proxy_plugin.return_value.do_intercept.return_value = True

        self.mock_selector.return_value.select.side_effect = [
            # Trigger read on plain socket
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
            # Trigger read on encrypted socket
            [(
                selectors.SelectorKey(
                    fileobj=client_tls_sock.fileno(),
                    fd=client_tls_sock.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

        # First iteration: process the CONNECT request.
        await self.protocol_handler._run_once()

        # Assert correct plugin was initialized
        self.plugin.assert_not_called()
        self.proxy_plugin.assert_called_once()
        self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags)
        # The client argument was constructed around the plain ``_conn``,
        # but its connection reference is swapped during interception, so
        # it must now be ``client_tls_sock`` (``mock_ssl_wrap``'s return).
        self.assertEqual(
            self.proxy_plugin.call_args[0][2].connection,
            client_tls_sock,
        )

        # Invoked lifecycle callbacks
        self.proxy_plugin.return_value.resolve_dns.assert_called_once_with(
            host,
            port,
        )
        self.proxy_plugin.return_value.before_upstream_connection.assert_called(
        )
        self.proxy_plugin.return_value.handle_client_request.assert_called_once(
        )
        self.proxy_plugin.return_value.do_intercept.assert_called_once()
        # All the invoked lifecycle callbacks receive the same CONNECT
        # request packet; being authority-form (``host:port``), its URL has
        # no path remainder.
        callback_request: HttpParser = \
            self.proxy_plugin.return_value.before_upstream_connection.call_args_list[0][0][0]
        callback_request1: HttpParser = \
            self.proxy_plugin.return_value.handle_client_request.call_args_list[0][0][0]
        callback_request2: HttpParser = \
            self.proxy_plugin.return_value.do_intercept.call_args_list[0][0][0]
        self.assertEqual(callback_request.host, bytes_(host))
        self.assertEqual(callback_request.port, 443)
        self.assertEqual(callback_request.header(b'Host'), headers[b'Host'])
        assert callback_request._url
        self.assertEqual(callback_request._url.remainder, None)
        self.assertEqual(callback_request.method, httpMethods.CONNECT)
        self.assertEqual(callback_request.is_https_tunnel, True)
        self.assertEqual(callback_request.build(), callback_request1.build())
        self.assertEqual(callback_request.build(), callback_request2.build())
        # Lifecycle callbacks not invoked
        self.proxy_plugin.return_value.handle_client_data.assert_not_called()
        self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called(
        )
        self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called(
        )
        self.proxy_plugin.return_value.on_access_log.assert_not_called()

        # Upstream side: connected, made non-blocking and TLS-wrapped.
        self.mock_server_conn.assert_called_with(host, port)
        self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
            False, )

        self.mock_ssl_context.assert_called_with(
            ssl.Purpose.SERVER_AUTH,
            cafile=str(DEFAULT_CA_FILE),
        )
        self.assertEqual(plain_connection.setblocking.call_count, 2)
        self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
            plain_connection,
            server_hostname=host,
        )
        self.assertEqual(self.mock_sign_csr.call_count, 1)
        self.assertEqual(self.mock_gen_csr.call_count, 1)
        self.assertEqual(self.mock_gen_public_key.call_count, 1)
        self.assertEqual(upstream_tls_sock.setblocking.call_count, 1)
        self.assertEqual(
            self.mock_server_conn.return_value._conn,
            upstream_tls_sock,
        )
        # Client side: tunnel response sent, then wrapped server-side with
        # the generated per-host certificate.
        self._conn.send.assert_called_with(
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT, )
        assert self.flags.ca_cert_dir is not None
        self.mock_ssl_wrap.assert_called_with(
            self._conn,
            server_side=True,
            keyfile=self.flags.ca_signing_key_file,
            certfile=HttpProxyPlugin.generated_cert_file_path(
                self.flags.ca_cert_dir,
                host,
            ),
            ssl_version=ssl.PROTOCOL_TLS,
        )
        self.assertEqual(self._conn.setblocking.call_count, 2)
        self.assertEqual(
            self.protocol_handler.work.connection,
            client_tls_sock,
        )

        # Assert connection references for all other plugins is updated
        self.assertEqual(
            self.proxy_plugin.return_value.client._conn,
            client_tls_sock,
        )

        # Now process the GET request
        await self.protocol_handler._run_once()
        self.plugin.assert_not_called()
        self.proxy_plugin.assert_called_once()

        # Lifecycle callbacks still not invoked
        self.proxy_plugin.return_value.handle_client_data.assert_not_called()
        self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called(
        )
        self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called(
        )
        self.proxy_plugin.return_value.on_access_log.assert_not_called()
        # Only handle client request lifecycle must be called again
        self.proxy_plugin.return_value.resolve_dns.assert_called_once_with(
            host,
            port,
        )
        self.proxy_plugin.return_value.before_upstream_connection.assert_called(
        )
        self.assertEqual(
            self.proxy_plugin.return_value.handle_client_request.call_count,
            2,
        )
        self.proxy_plugin.return_value.do_intercept.assert_called_once()

        # The second handle_client_request call receives the decrypted
        # origin-form GET (path ``/``, no host in the request line).
        callback_request = \
            self.proxy_plugin.return_value.handle_client_request.call_args_list[1][0][0]
        self.assertEqual(callback_request.host, None)
        self.assertEqual(callback_request.port, 80)
        self.assertEqual(callback_request.header(b'Host'), headers[b'Host'])
        assert callback_request._url
        self.assertEqual(callback_request._url.remainder, b'/')
        self.assertEqual(callback_request.method, httpMethods.GET)
        self.assertEqual(callback_request.is_https_tunnel, False)
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Prepare a TLS-interception scenario for the example plugin named
        by ``request.param``.

        Patches socket, selector, TLS and certificate helpers. Installs the
        plugin returned by ``get_plugin_by_test_name(request.param)`` as the
        only ``HttpProxyBasePlugin``. Queues a ``CONNECT uni.corn:443`` on
        the client socket and scripts four selector results.
        """
        # NOTE(review): no ``@pytest.fixture`` decorator directly precedes
        # this method here, yet it reads ``request.param`` (only set for an
        # indirectly-parametrized fixture) -- confirm it is registered as one.
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
        self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
        self.mock_gen_public_key = mocker.patch(
            'proxy.http.proxy.server.gen_public_key', )
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection', )
        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')

        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        # Supplying CA files enables TLS interception.
        self.flags = FlagParser.initialize(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
            threaded=True,
        )
        self.plugin = mocker.MagicMock()

        plugin = get_plugin_by_test_name(request.param)

        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = mocker.MagicMock(spec=socket.socket)
        self.mock_fromfd.return_value = self._conn
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

        self.server = self.mock_server_conn.return_value

        # Distinct TLS socket mocks for the upstream and client sides.
        self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection
        self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = self.client_ssl_connection

        # Upstream has pending data once anything has been queued to it.
        def has_buffer() -> bool:
            return cast(bool, self.server.queue.called)

        # Upstream counts as closed until connect() has been called.
        def closed() -> bool:
            return not self.server.connect.called

        # Upstream ``connection`` is the plain client mock until the SSL
        # context's wrap_socket has been called, then the upstream TLS mock.
        def mock_connection() -> Any:
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return self.server_ssl_connection
            return self._conn

        # Do not mock the original wrap method
        self.server.wrap.side_effect = \
            lambda x, y, as_non_blocking: TcpServerConnection.wrap(
                self.server, x, y, as_non_blocking=as_non_blocking,
            )

        self.server.has_buffer.side_effect = has_buffer
        type(self.server).closed = mocker.PropertyMock(side_effect=closed)
        type(self.server, ).connection = mocker.PropertyMock(
            side_effect=mock_connection, )

        # Selector results, in order: client plain-socket read, client TLS
        # read, upstream TLS write, upstream TLS read.
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self.client_ssl_connection.fileno(),
                    fd=self.client_ssl_connection.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self.server_ssl_connection.fileno(),
                    fd=self.server_ssl_connection.fileno(),
                    events=selectors.EVENT_WRITE,
                    data=None,
                ),
                selectors.EVENT_WRITE,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self.server_ssl_connection.fileno(),
                    fd=self.server_ssl_connection.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

        # Connect
        # Client sends report every byte as written.
        def send(raw: bytes) -> int:
            return len(raw)

        self._conn.send.side_effect = send
        self._conn.recv.return_value = build_http_request(
            httpMethods.CONNECT,
            b'uni.corn:443',
            no_ua=True,
        )