def _setUp(self, request: Any, mocker: MockerFixture) -> None:
    """Build a protocol handler whose proxy pipeline runs one example plugin.

    The plugin class is looked up from ``request.param`` (a test name) and
    installed as the only ``HttpProxyBasePlugin``. ``socket.fromfd``, the
    default selector and the upstream ``TcpServerConnection`` are patched.
    """
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.mock_server_conn = mocker.patch(
        'proxy.http.proxy.server.TcpServerConnection',
    )
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    # Path resolved relative to this file: three levels up, then
    # proxy/plugin/adblock.json.
    base_dir = Path(__file__).parent.parent.parent
    adblock_config = base_dir / "proxy" / "plugin" / "adblock.json"
    self.flags = FlagParser.initialize(
        input_args=["--filtered-url-regex-config", str(adblock_config)],
        threaded=True,
    )
    self.plugin = mock.MagicMock()
    example_plugin = get_plugin_by_test_name(request.param)
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [example_plugin],
    }
    self._conn = self.mock_fromfd.return_value
    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def _setUp(self, mocker: MockerFixture) -> None:
    """Prepare a handler with the static file server enabled.

    Writes ``index.html`` into ``<system tempdir>/static`` and points the
    ``static_server_dir`` flag at that directory.
    """
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self._conn = self.mock_fromfd.return_value

    static_dir = os.path.join(tempfile.gettempdir(), 'static')
    self.static_server_dir = static_dir
    self.index_file_path = os.path.join(static_dir, 'index.html')
    self.html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''
    os.makedirs(static_dir, exist_ok=True)
    with open(self.index_file_path, 'wb') as index_file:
        index_file.write(self.html_file_content)

    static_flags = FlagParser.initialize(
        enable_static_server=True,
        static_server_dir=static_dir,
        threaded=True,
    )
    static_flags.plugins = Plugins.load([
        bytes_(PLUGIN_HTTP_PROXY),
        bytes_(PLUGIN_WEB_SERVER),
    ])
    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=static_flags)
    self.protocol_handler.initialize()
async def test_proxy_authentication_failed(self) -> None:
    """A proxied GET without credentials is answered with the auth-failed packet."""
    self._conn = self.mock_fromfd.return_value
    mock_selector_for_client_read(self)

    auth_flags = FlagParser.initialize(
        auth_code=base64.b64encode(b'user:pass'),
        threaded=True,
    )
    plugin_names = (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER, PLUGIN_PROXY_AUTH)
    auth_flags.plugins = Plugins.load([bytes_(name) for name in plugin_names])

    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=auth_flags)
    self.protocol_handler.initialize()

    request_lines = [
        b'GET http://abhinavsingh.com HTTP/1.1',
        b'Host: abhinavsingh.com',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)
    await self.protocol_handler._run_once()

    queued = self.protocol_handler.work.buffer[0]
    self.assertEqual(queued, PROXY_AUTH_FAILED_RESPONSE_PKT)
def _setUp(self, request: Any, mocker: MockerFixture) -> None:
    """Build a handler serving a PAC file given as a path or as raw content.

    ``request.param`` is either a file path (``str``) whose bytes become the
    expected response body, or inline content, in which case
    ``PAC_FILE_CONTENT`` is expected. The client is primed with ``GET /``.
    """
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self._conn = self.mock_fromfd.return_value

    self.pac_file = request.param
    if not isinstance(self.pac_file, str):
        self.expected_response = PAC_FILE_CONTENT
    else:
        with open(self.pac_file, 'rb') as pac:
            self.expected_response = pac.read()

    self.flags = FlagParser.initialize(
        pac_file=self.pac_file,
        threaded=True,
    )
    enabled = (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER, PLUGIN_PAC_FILE)
    self.flags.plugins = Plugins.load([bytes_(name) for name in enabled])

    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()

    self._conn.recv.return_value = CRLF.join([b'GET / HTTP/1.1', CRLF])
    mock_selector_for_client_read(self)
class TestWebServerPlugin(Assertions):
    """Built-in web server with only the HTTP proxy and web server plugins."""

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        """Patch socket/selector and build an initialized protocol handler."""
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value
        self.flags = FlagParser.initialize(threaded=True)
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_default_web_server_returns_404(self) -> None:
        """A GET for an unknown path completes parsing and queues a 404."""
        # Reuse the handler from the autouse fixture. It is already built with
        # the same flags and plugins, so a second handler is unnecessary.
        self.mock_selector.return_value.select.return_value = [
            (
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            ),
        ]
        self._conn.recv.return_value = CRLF.join([
            b'GET /hello HTTP/1.1',
            CRLF,
        ])
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            NOT_FOUND_RESPONSE_PKT,
        )
class TestWebServerPluginWithPacFilePlugin(Assertions):
    """PAC file serving, parametrized over a file path and inline content."""

    @pytest.fixture(
        autouse=True, params=[
            PAC_FILE_PATH,
            PAC_FILE_CONTENT,
        ],
    )  # type: ignore[misc]
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Build a PAC-serving handler and prime the client with ``GET /``.

        ``request.param`` is a path (``str``), whose file bytes are the
        expected body, or raw content, in which case ``PAC_FILE_CONTENT`` is
        expected.
        """
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value
        self.pac_file = request.param
        if isinstance(self.pac_file, str):
            with open(self.pac_file, 'rb') as f:
                self.expected_response = f.read()
        else:
            self.expected_response = PAC_FILE_CONTENT
        self.flags = FlagParser.initialize(
            pac_file=self.pac_file,
            threaded=True,
        )
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PAC_FILE),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET / HTTP/1.1',
            CRLF,
        ])
        mock_selector_for_client_read(self)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_pac_file_served_from_disk(self) -> None:
        """The PAC content is queued for the client with a 200 status.

        The previous check called ``self._conn.send.called_once_with(...)``.
        That is not a Mock assertion: it only creates and calls an
        auto-attribute, so it could never fail. Only a client *read* event is
        simulated, so, as in the sibling tests, the response is inspected on
        ``work.buffer`` instead of ``send``.
        """
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )
        self.assertTrue(self.protocol_handler.work.has_buffer())
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        self.assertEqual(response.code, b'200')
        self.assertEqual(
            response.header(b'content-type'),
            b'application/x-ns-proxy-autoconfig',
        )
        body = response.body
        assert body is not None
        # NOTE(review): whether the plugin gzip-compresses the PAC response is
        # not visible here; decode only when the gzip magic bytes are present.
        if bytes(body[:2]) == b'\x1f\x8b':
            body = gzip.decompress(body)
        self.assertEqual(bytes(body), self.expected_response)
async def test_authenticated_proxy_http_get(self) -> None:
    """A plain-HTTP GET carrying valid Proxy-Authorization is forwarded upstream.

    The request arrives in three reads (request line, bare CRLF, headers) so
    the parser state can be checked after each step.
    """
    self._conn = self.mock_fromfd.return_value
    mock_selector_for_client_read(self)

    server = self.mock_server_connection.return_value
    server.connect.return_value = True
    server.buffer_size.return_value = 0

    flags = FlagParser.initialize(
        auth_code=base64.b64encode(b'user:pass'),
        threaded=True,
    )
    # Load the proxy-auth plugin, as the auth-failure test does. Without it,
    # ``auth_code`` is presumably never enforced and the credentials below are
    # not checked at all. NOTE(review): assumes enforcement lives in that plugin.
    flags.plugins = Plugins.load([
        bytes_(PLUGIN_HTTP_PROXY),
        bytes_(PLUGIN_WEB_SERVER),
        bytes_(PLUGIN_PROXY_AUTH),
    ])
    self.protocol_handler = HttpProtocolHandler(
        HttpClientConnection(self._conn, self._addr),
        flags=flags,
    )
    self.protocol_handler.initialize()

    assert self.http_server_port is not None
    port = self.http_server_port

    # Request line without its terminating CRLF: nothing parsed yet.
    self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % port
    await self.protocol_handler._run_once()
    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.INITIALIZED,
    )

    # Completing the line moves the parser to LINE_RCVD.
    self._conn.recv.return_value = CRLF
    await self.protocol_handler._run_once()
    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.LINE_RCVD,
    )

    self._conn.recv.return_value = CRLF.join([
        b'User-Agent: proxy.py/%s' % bytes_(__version__),
        b'Host: localhost:%d' % port,
        b'Accept: */*',
        httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
        httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
        CRLF,
    ])
    await self.assert_data_queued(server)
def _setUp(self, mocker: MockerFixture) -> None:
    """Create an initialized handler with the HTTP proxy and web server plugins."""
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self._conn = self.mock_fromfd.return_value
    self.flags = FlagParser.initialize(threaded=True)
    enabled = (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER)
    self.flags.plugins = Plugins.load([bytes_(name) for name in enabled])
    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def _setUp(self, mocker: MockerFixture) -> None:
    """Create a handler configured from ``--basic-auth user:pass``.

    ``socket.fromfd``, the default selector and the upstream
    ``TcpServerConnection`` are patched.
    """
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.mock_server_conn = mocker.patch(
        'proxy.http.proxy.server.TcpServerConnection',
    )
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    cli_args = ["--basic-auth", "user:pass"]
    self.flags = FlagParser.initialize(cli_args, threaded=True)
    self._conn = self.mock_fromfd.return_value
    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def _setUp(self, mocker: MockerFixture) -> None:
    """Create a handler whose only proxy plugin is a ``MagicMock``.

    The mock is kept on ``self.plugin`` so tests can inspect how the proxy
    pipeline invokes it.
    """
    self.mock_server_conn = mocker.patch(
        'proxy.http.proxy.server.TcpServerConnection',
    )
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = FlagParser.initialize(threaded=True)
    self.plugin = mocker.MagicMock()
    plugin_map = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.plugin],
    }
    self.flags.plugins = plugin_map
    self._conn = self.mock_fromfd.return_value
    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
async def test_authenticated_proxy_http_tunnel(self) -> None:
    """A CONNECT carrying valid Proxy-Authorization establishes a tunnel.

    After the tunnel response is flushed, queued client data must reach the
    upstream server and be flushed on the next cycle.
    """
    server = self.mock_server_connection.return_value
    server.connect.return_value = True
    server.buffer_size.return_value = 0
    self._conn = self.mock_fromfd.return_value
    self.mock_selector_for_client_read_and_server_write(server)

    flags = FlagParser.initialize(
        auth_code=base64.b64encode(b'user:pass'),
        threaded=True,
    )
    # Load the proxy-auth plugin, as the auth-failure test does. Without it,
    # ``auth_code`` is presumably never enforced and the credentials are not
    # checked. NOTE(review): assumes enforcement lives in that plugin.
    flags.plugins = Plugins.load([
        bytes_(PLUGIN_HTTP_PROXY),
        bytes_(PLUGIN_WEB_SERVER),
        bytes_(PLUGIN_PROXY_AUTH),
    ])
    self.protocol_handler = HttpProtocolHandler(
        HttpClientConnection(self._conn, self._addr),
        flags=flags,
    )
    self.protocol_handler.initialize()

    assert self.http_server_port is not None
    port = self.http_server_port
    self._conn.recv.return_value = CRLF.join([
        b'CONNECT localhost:%d HTTP/1.1' % port,
        b'Host: localhost:%d' % port,
        b'User-Agent: proxy.py/%s' % bytes_(__version__),
        httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
        httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
        CRLF,
    ])
    await self.assert_tunnel_response(server)
    self.protocol_handler.work.flush()
    await self.assert_data_queued_to_server(server)

    await self.protocol_handler._run_once()
    server.flush.assert_called_once()
async def test_default_web_server_returns_404(self) -> None:
    """A GET for an unknown path on the bare web server queues a 404."""
    self._conn = self.mock_fromfd.return_value
    readable = selectors.SelectorKey(
        fileobj=self._conn.fileno(),
        fd=self._conn.fileno(),
        events=selectors.EVENT_READ,
        data=None,
    )
    self.mock_selector.return_value.select.return_value = [
        (readable, selectors.EVENT_READ),
    ]

    web_flags = FlagParser.initialize(threaded=True)
    enabled = (PLUGIN_HTTP_PROXY, PLUGIN_WEB_SERVER)
    web_flags.plugins = Plugins.load([bytes_(name) for name in enabled])
    client = HttpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=web_flags)
    self.protocol_handler.initialize()

    self._conn.recv.return_value = CRLF.join([b'GET /hello HTTP/1.1', CRLF])
    await self.protocol_handler._run_once()

    parsed_state = self.protocol_handler.request.state
    self.assertEqual(parsed_state, httpParserStates.COMPLETE)
    queued = self.protocol_handler.work.buffer[0]
    self.assertEqual(queued, NOT_FOUND_RESPONSE_PKT)
def test_on_client_connection_called_on_teardown(mocker: MockerFixture) -> None:
    """Running the handler until ``_run_once`` returns True closes the client."""
    protocol_plugin = mocker.MagicMock()
    fromfd = mocker.patch('socket.fromfd')
    flags = FlagParser.initialize(threaded=True)
    flags.plugins = {b'HttpProtocolHandlerPlugin': [protocol_plugin]}
    client_sock = fromfd.return_value
    handler = HttpProtocolHandler(
        HttpClientConnection(client_sock, ('127.0.0.1', 54382)),
        flags=flags,
    )
    handler.initialize()
    # The plugin class has not been called (instantiated) by initialize().
    protocol_plugin.assert_not_called()
    # Presumably a truthy _run_once result tells run() to stop and tear down.
    patched_run_once = mocker.patch.object(handler, '_run_once')
    patched_run_once.return_value = True
    handler.run()
    assert client_sock.closed
class TestHttpProxyPluginExamples(Assertions):
    """Drive each example ``HttpProxyBasePlugin`` through the proxy pipeline.

    The autouse ``_setUp`` fixture is parametrized indirectly with a test
    name. The matching plugin comes from ``get_plugin_by_test_name``.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        """Patch socket, selector and upstream connection; load one plugin."""
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        adblock_json_path = Path(
            __file__,
        ).parent.parent.parent / "proxy" / "plugin" / "adblock.json"
        self.flags = FlagParser.initialize(
            input_args=[
                "--filtered-url-regex-config",
                str(adblock_json_path),
            ],
            threaded=True,
        )
        self.plugin = mock.MagicMock()
        plugin = get_plugin_by_test_name(request.param)
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    @staticmethod
    def _select_result(fileno: Any, event: int) -> Any:
        """Build one ``(SelectorKey, event)`` pair shaped like ``select()`` output."""
        return (
            selectors.SelectorKey(
                fileobj=fileno,
                fd=fileno,
                events=event,
                data=None,
            ),
            event,
        )

    def _mock_select_sequence(self, *steps: Any) -> None:
        """Script successive ``select()`` calls, one ``(fileno, event)`` step per call.

        ``side_effect`` is a list, so each call consumes one step. A call past
        the scripted steps raises ``StopIteration``.
        """
        self.mock_selector.return_value.select.side_effect = [
            [self._select_result(fileno, event)] for fileno, event in steps
        ]

    def _mock_client_read(self) -> None:
        """Make the next (and only) ``select()`` report the client as readable."""
        self._mock_select_sequence((self._conn.fileno(), selectors.EVENT_READ))

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_modify_post_data_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_modify_post_data_plugin(self) -> None:
        """The POST body and content headers are rewritten before going upstream."""
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'
        self._conn.recv.return_value = build_http_request(
            b'POST', b'http://httpbin.org/post',
            headers={
                b'Host': b'httpbin.org',
                b'Content-Type': b'application/x-www-form-urlencoded',
                b'Content-Length': bytes_(len(original)),
            },
            body=original,
            no_ua=True,
        )
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with(
            'httpbin.org', DEFAULT_HTTP_PORT,
        )
        self.mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(
                b'POST', b'/post',
                headers={
                    b'Host': b'httpbin.org',
                    b'Content-Type': b'application/json',
                    b'Content-Length': bytes_(len(modified)),
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                body=modified,
                no_ua=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_proposed_rest_api_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_proposed_rest_api_plugin(self) -> None:
        """A mocked REST endpoint answers locally with a gzip JSON body."""
        path = b'/v1/users/'
        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://%s%s' % (
                ProposedRestApiPlugin.API_SERVER, path,
            ),
            headers={
                b'Host': ProposedRestApiPlugin.API_SERVER,
            },
        )
        self._mock_client_read()
        await self.protocol_handler._run_once()

        self.mock_server_conn.assert_not_called()
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        assert response.body
        self.assertEqual(
            response.header(b'content-type'),
            b'application/json',
        )
        self.assertEqual(
            response.header(b'content-encoding'),
            b'gzip',
        )
        self.assertEqual(
            gzip.decompress(response.body),
            bytes_(
                json.dumps(
                    ProposedRestApiPlugin.REST_API_SPEC[path],
                ),
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_redirect_to_custom_server_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_redirect_to_custom_server_plugin(self) -> None:
        """Plain-HTTP requests are re-routed to the plugin's upstream server."""
        request = build_http_request(
            b'GET', b'http://example.org/get',
            headers={
                b'Host': b'example.org',
            },
            no_ua=True,
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()

        upstream = urlparse.urlsplit(
            RedirectToCustomServerPlugin.UPSTREAM_SERVER,
        )
        self.mock_server_conn.assert_called_with('localhost', 8899)
        self.mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(
                b'GET', upstream.path,
                headers={
                    b'Host': upstream.netloc,
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                no_ua=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_redirect_to_custom_server_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_redirect_to_custom_server_plugin_skips_https(self) -> None:
        """CONNECT requests are tunnelled to the original host, not redirected."""
        request = build_http_request(
            b'CONNECT', b'jaxl.com:443',
            headers={
                b'Host': b'jaxl.com:443',
            },
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with('jaxl.com', 443)
        self.assertEqual(
            self.protocol_handler.work.buffer[0].tobytes(),
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_filter_by_upstream_host_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_filter_by_upstream_host_plugin(self) -> None:
        """A blocked upstream host gets a teapot response and no upstream connection."""
        request = build_http_request(
            b'GET', b'http://facebook.com/',
            headers={
                b'Host': b'facebook.com',
            },
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_not_called()
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            build_http_response(
                status_code=httpStatusCodes.I_AM_A_TEAPOT,
                reason=b'I\'m a tea pot',
                conn_close=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_man_in_the_middle_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_man_in_the_middle_plugin(self) -> None:
        """The upstream response body is replaced by the plugin's own content.

        Runs three cycles: client read, upstream write, upstream read.
        """
        request = build_http_request(
            b'GET', b'http://super.secure/',
            headers={
                b'Host': b'super.secure',
            },
            no_ua=True,
        )
        self._conn.recv.return_value = request

        server = self.mock_server_conn.return_value
        server.connect.return_value = True

        def has_buffer() -> bool:
            # The upstream has pending data once something was queued to it.
            return cast(bool, server.queue.called)

        def closed() -> bool:
            # The upstream counts as closed until connect() has been called.
            return not server.connect.called

        server.has_buffer.side_effect = has_buffer
        type(server).closed = mock.PropertyMock(side_effect=closed)

        self._mock_select_sequence(
            (self._conn.fileno(), selectors.EVENT_READ),
            (server.connection.fileno(), selectors.EVENT_WRITE),
            (server.connection.fileno(), selectors.EVENT_READ),
        )

        # Client read
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with(
            'super.secure', DEFAULT_HTTP_PORT,
        )
        server.connect.assert_called_once()
        queued_request = build_http_request(
            b'GET', b'/',
            headers={
                b'Host': b'super.secure',
                b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
            },
            no_ua=True,
        )
        server.queue.assert_called_once()
        self.assertEqual(server.queue.call_args_list[0][0][0], queued_request)

        # Server write
        await self.protocol_handler._run_once()
        server.flush.assert_called_once()

        # Server read
        server.recv.return_value = build_http_response(
            httpStatusCodes.OK,
            reason=b'OK',
            body=b'Original Response From Upstream',
        )
        await self.protocol_handler._run_once()
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        assert response.body
        self.assertEqual(
            gzip.decompress(response.body),
            b'Hello from man in the middle',
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_filter_by_url_regex_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_filter_by_url_regex_plugin(self) -> None:
        """A URL matching the adblock regex config is answered with 404 'Blocked'."""
        request = build_http_request(
            b'GET', b'http://www.facebook.com/tr/',
            headers={
                b'Host': b'www.facebook.com',
            },
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            build_http_response(
                status_code=httpStatusCodes.NOT_FOUND,
                reason=b'Blocked',
                conn_close=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_shortlink_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_shortlink_plugin(self) -> None:
        """A known shortlink host redirects (303) to its target URL."""
        request = build_http_request(
            b'GET', b'http://t/',
            headers={
                b'Host': b't',
            },
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.work.buffer[0].tobytes(),
            build_http_response(
                status_code=httpStatusCodes.SEE_OTHER,
                reason=b'See Other',
                headers={
                    b'Location': b'http://twitter.com/',
                },
                conn_close=True,
            ),
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_shortlink_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_shortlink_plugin_unknown(self) -> None:
        """An unknown dot-less shortlink host is answered with 404."""
        request = build_http_request(
            b'GET', b'http://unknown/',
            headers={
                b'Host': b'unknown',
            },
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.work.buffer[0].tobytes(),
            NOT_FOUND_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        "_setUp",
        ('test_shortlink_plugin',),
        indirect=True,
    )  # type: ignore[misc]
    async def test_shortlink_plugin_external(self) -> None:
        """A regular domain bypasses the shortlink plugin and is proxied upstream."""
        request = build_http_request(
            b'GET', b'http://jaxl.com/',
            headers={
                b'Host': b'jaxl.com',
            },
            no_ua=True,
        )
        self._conn.recv.return_value = request
        self._mock_client_read()
        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_once_with('jaxl.com', 80)
        self.mock_server_conn.return_value.queue.assert_called_with(
            build_http_request(
                b'GET', b'/',
                headers={
                    b'Host': b'jaxl.com',
                    b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
                },
                no_ua=True,
            ),
        )
        self.assertFalse(self.protocol_handler.work.has_buffer())
class TestHttpProxyAuthFailed(Assertions):
    """Proxy basic authentication configured through ``--basic-auth user:pass``."""

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        """Patch socket, selector and upstream connection; build the handler."""
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        cli_args = ["--basic-auth", "user:pass"]
        self.flags = FlagParser.initialize(cli_args, threaded=True)
        self._conn = self.mock_fromfd.return_value
        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

    async def _proxy_get(self, authorization: Any = None) -> None:
        """Send one proxied GET to upstream.host and run a single read cycle.

        ``authorization``, when given, becomes the Proxy-Authorization value.
        """
        headers = {b'Host': b'upstream.host'}
        if authorization is not None:
            headers[httpHeaders.PROXY_AUTHORIZATION] = authorization
        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://upstream.host/not-found.html',
            headers=headers,
        )
        readable = selectors.SelectorKey(
            fileobj=self._conn.fileno(),
            fd=self._conn.fileno(),
            events=selectors.EVENT_READ,
            data=None,
        )
        self.mock_selector.return_value.select.side_effect = [
            [(readable, selectors.EVENT_READ)],
        ]
        await self.protocol_handler._run_once()

    def _assert_rejected(self) -> None:
        """No upstream connection; the auth-failed packet is queued, not sent."""
        self.mock_server_conn.assert_not_called()
        self.assertEqual(self.protocol_handler.work.has_buffer(), True)
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            PROXY_AUTH_FAILED_RESPONSE_PKT,
        )
        self._conn.send.assert_not_called()

    def _assert_forwarded(self) -> None:
        """Exactly one upstream connection and nothing queued for the client."""
        self.mock_server_conn.assert_called_once()
        self.assertEqual(self.protocol_handler.work.has_buffer(), False)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_fails_without_cred(self) -> None:
        await self._proxy_get()
        self._assert_rejected()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_fails_with_invalid_cred(self) -> None:
        await self._proxy_get(b'Basic hello')
        self._assert_rejected()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_works_with_valid_cred(self) -> None:
        await self._proxy_get(b'Basic dXNlcjpwYXNz')
        self._assert_forwarded()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_auth_works_with_mixed_case_basic_string(self) -> None:
        await self._proxy_get(b'bAsIc dXNlcjpwYXNz')
        self._assert_forwarded()
class TestStaticWebServerPlugin(Assertions):
    """Static file serving from ``<system tempdir>/static``."""

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        """Write ``index.html`` to the static dir and build a static-serving handler."""
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value

        static_dir = os.path.join(tempfile.gettempdir(), 'static')
        self.static_server_dir = static_dir
        self.index_file_path = os.path.join(static_dir, 'index.html')
        self.html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''
        os.makedirs(static_dir, exist_ok=True)
        with open(self.index_file_path, 'wb') as index_file:
            index_file.write(self.html_file_content)

        static_flags = FlagParser.initialize(
            enable_static_server=True,
            static_server_dir=static_dir,
            threaded=True,
        )
        static_flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])
        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=static_flags)
        self.protocol_handler.initialize()

    def _client_event(self, event: int) -> Any:
        """One ``select()`` result reporting ``event`` on the client socket."""
        key = selectors.SelectorKey(
            fileobj=self._conn.fileno(),
            fd=self._conn.fileno(),
            events=event,
            data=None,
        )
        return [(key, event)]

    async def _get_and_flush(self, path: bytes) -> None:
        """GET ``path``, then run a read cycle followed by a write cycle."""
        self._conn.recv.return_value = build_http_request(b'GET', path)
        self.mock_selector.return_value.select.side_effect = [
            self._client_event(selectors.EVENT_READ),
            self._client_event(selectors.EVENT_WRITE),
        ]
        await self.protocol_handler._run_once()
        await self.protocol_handler._run_once()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_static_web_server_serves(self) -> None:
        await self._get_and_flush(b'/index.html')
        self.assertEqual(self.mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        encoded_html_file_content = gzip.compress(self.html_file_content)

        # Parse what was sent and verify headers and the gzip body.
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self._conn.send.call_args[0][0])
        expected_headers = [
            (b'content-type', b'text/html'),
            (b'cache-control', b'max-age=86400'),
            (b'content-encoding', b'gzip'),
            (b'connection', b'close'),
            (b'content-length', bytes_(len(encoded_html_file_content))),
        ]
        self.assertEqual(response.code, b'200')
        for header_name, header_value in expected_headers:
            self.assertEqual(response.header(header_name), header_value)
        assert response.body
        self.assertEqual(
            gzip.decompress(response.body),
            self.html_file_content,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_static_web_server_serves_404(self) -> None:
        await self._get_and_flush(b'/not-found.html')
        self.assertEqual(self.mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        self.assertEqual(
            self._conn.send.call_args[0][0],
            NOT_FOUND_RESPONSE_PKT,
        )
class TestHttpProtocolHandlerWithoutServerMock(Assertions):
    """Error replies generated by the protocol handler and its plugins.

    ``TcpServerConnection`` is intentionally left unpatched in this class, so
    the responses asserted here come from the handler itself rather than
    from a mocked upstream.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value
        self.http_server_port = 65535

        default_plugins = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
        self.flags = FlagParser.initialize(threaded=True)
        self.flags.plugins = Plugins.load(default_plugins)

        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

    async def _expect_reply(
            self,
            raw_request: bytes,
            expected_reply: bytes,
    ) -> None:
        """Deliver ``raw_request`` on one readable event, then check the first queued reply."""
        mock_selector_for_client_read(self)
        self._conn.recv.return_value = raw_request
        await self.protocol_handler._run_once()
        self.assertEqual(self.protocol_handler.work.buffer[0], expected_reply)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_connection_failed(self) -> None:
        await self._expect_reply(
            CRLF.join([
                b'GET http://unknown.domain HTTP/1.1',
                b'Host: unknown.domain',
                CRLF,
            ]),
            BAD_GATEWAY_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_authentication_failed(self) -> None:
        # Selector must be primed before the replacement handler is built,
        # mirroring the order used throughout this file.
        mock_selector_for_client_read(self)

        auth_flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        auth_flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PROXY_AUTH),
        ])
        client = HttpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=auth_flags)
        self.protocol_handler.initialize()

        # No Proxy-Authorization header is sent.
        unauthenticated = CRLF.join([
            b'GET http://abhinavsingh.com HTTP/1.1',
            b'Host: abhinavsingh.com',
            CRLF,
        ])
        self._conn.recv.return_value = unauthenticated
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            PROXY_AUTH_FAILED_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_bails_out_for_unknown_schemes(self) -> None:
        await self._expect_reply(
            CRLF.join([
                b'REQMOD icap://icap-server.net/server?arg=87 ICAP/1.0',
                b'Host: icap-server.net',
                CRLF,
            ]),
            BAD_REQUEST_RESPONSE_PKT,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_bails_out_for_sip_request_lines(self) -> None:
        await self._expect_reply(
            CRLF.join([
                b'OPTIONS sip:nm SIP/2.0',
                b'Accept: application/sdp',
                CRLF,
            ]),
            BAD_REQUEST_RESPONSE_PKT,
        )
class TestHttpProtocolHandler(Assertions):
    """Protocol handler tests with ``TcpServerConnection`` patched out.

    The selector mock is primed per test; each ``_run_once`` call consumes
    one ``select()`` result, so the order of ``_run_once`` calls matters.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_server_connection = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = self.mock_fromfd.return_value

        self.http_server_port = 65535
        self.flags = FlagParser.initialize(threaded=True)
        self.flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_http_get(self) -> None:
        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        self.mock_selector_for_client_read_and_server_write(server)

        # Send request line
        assert self.http_server_port is not None
        self._conn.recv.return_value = (
            b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
        ) + CRLF
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.LINE_RCVD,
        )
        self.assertNotEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )

        # Send headers and blank line, thus completing HTTP request
        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            CRLF,
        ])
        await self.assert_data_queued(server)
        await self.protocol_handler._run_once()
        server.flush.assert_called_once()

    async def assert_tunnel_response(
            self,
            server: mock.Mock,
    ) -> None:
        """Run one iteration and check a CONNECT produced a 200 tunnel reply."""
        await self.protocol_handler._run_once()
        self.assertTrue(
            cast(
                HttpProxyPlugin,
                self.protocol_handler.plugin,
            ).upstream is not None,
        )
        self.assertEqual(
            self.protocol_handler.work.buffer[0],
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )
        self.mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.queue.assert_not_called()
        server.closed = False

        parser = HttpParser(httpParserTypes.RESPONSE_PARSER)
        parser.parse(self.protocol_handler.work.buffer[0])
        self.assertEqual(parser.state, httpParserStates.COMPLETE)
        assert parser.code is not None
        self.assertEqual(int(parser.code), 200)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_http_tunnel(self) -> None:
        server = self.mock_server_connection.return_value
        server.connect.return_value = True

        def has_buffer() -> bool:
            return cast(bool, server.queue.called)

        server.has_buffer.side_effect = has_buffer
        self.mock_selector.return_value.select.side_effect = [
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=selectors.EVENT_READ,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_WRITE,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=selectors.EVENT_READ,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=server.connection.fileno(),
                        fd=server.connection.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_WRITE,
                ),
            ],
        ]

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Proxy-Connection: Keep-Alive',
            CRLF,
        ])
        await self.assert_tunnel_response(server)

        # Dispatch tunnel established response to client
        await self.protocol_handler._run_once()
        await self.assert_data_queued_to_server(server)

        await self.protocol_handler._run_once()
        self.assertEqual(server.queue.call_count, 1)
        server.flush.assert_called_once()

    def _use_authenticated_handler(self) -> None:
        """Replace the fixture's handler with one requiring ``user:pass`` proxy auth.

        ``PLUGIN_PROXY_AUTH`` is loaded explicitly. Assigning ``flags.plugins``
        replaces whatever plugin list ``FlagParser`` produced, and the auth
        check presumably lives in that plugin; ``test_proxy_authentication_failed``
        loads it to get a 407. Without it, the authenticated tests would pass
        regardless of the credentials sent.
        """
        flags = FlagParser.initialize(
            auth_code=base64.b64encode(b'user:pass'),
            threaded=True,
        )
        flags.plugins = Plugins.load([
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PROXY_AUTH),
        ])
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=flags,
        )
        self.protocol_handler.initialize()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_authenticated_proxy_http_get(self) -> None:
        mock_selector_for_client_read(self)

        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        self._use_authenticated_handler()

        assert self.http_server_port is not None
        # Request line without trailing CRLF: parser must not advance yet.
        self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.INITIALIZED,
        )

        self._conn.recv.return_value = CRLF
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.LINE_RCVD,
        )

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
            # base64(b'user:pass') == b'dXNlcjpwYXNz'
            httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
            CRLF,
        ])
        await self.assert_data_queued(server)

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_authenticated_proxy_http_tunnel(self) -> None:
        server = self.mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        self.mock_selector_for_client_read_and_server_write(server)

        self._use_authenticated_handler()

        assert self.http_server_port is not None
        self._conn.recv.return_value = CRLF.join([
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            httpHeaders.PROXY_CONNECTION + b': Keep-Alive',
            httpHeaders.PROXY_AUTHORIZATION + b': Basic dXNlcjpwYXNz',
            CRLF,
        ])
        await self.assert_tunnel_response(server)
        self.protocol_handler.work.flush()
        await self.assert_data_queued_to_server(server)

        await self.protocol_handler._run_once()
        server.flush.assert_called_once()

    def mock_selector_for_client_read_and_server_write(
            self, server: mock.Mock,
    ) -> None:
        """Prime three selector events: client read, client read, server write."""
        self.mock_selector.return_value.select.side_effect = [
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=selectors.EVENT_READ,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=self._conn.fileno(),
                        fd=self._conn.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_READ,
                ),
            ],
            [
                (
                    selectors.SelectorKey(
                        fileobj=server.connection.fileno(),
                        fd=server.connection.fileno(),
                        events=0,
                        data=None,
                    ),
                    selectors.EVENT_WRITE,
                ),
            ],
        ]

    async def assert_data_queued(
            self, server: mock.Mock,
    ) -> None:
        """Complete the request and check the rewritten packet queued upstream.

        The expected packet has the Proxy-* headers removed and a ``Via``
        header added.
        """
        await self.protocol_handler._run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE,
        )
        self.mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False

        assert self.http_server_port is not None
        pkt = CRLF.join([
            b'GET / HTTP/1.1',
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Via: 1.1 proxy.py v%s' % bytes_(__version__),
            CRLF,
        ])

        server.queue.assert_called_once()
        self.assertEqual(server.queue.call_args_list[0][0][0], pkt)
        server.buffer_size.return_value = len(pkt)

    async def assert_data_queued_to_server(self, server: mock.Mock) -> None:
        """After the tunnel reply, check raw client bytes are queued upstream unchanged."""
        assert self.http_server_port is not None
        self.assertEqual(
            self._conn.send.call_args[0][0],
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )

        pkt = CRLF.join([
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            CRLF,
        ])

        self._conn.recv.return_value = pkt
        await self.protocol_handler._run_once()

        server.queue.assert_called_once_with(pkt)
        server.buffer_size.return_value = len(pkt)
        server.flush.assert_not_called()
async def test_e2e(self, mocker: MockerFixture) -> None:
    """Drive one CONNECT through the proxy with TLS interception configured.

    Patches the socket, selector, certificate-generation helpers,
    ``TcpServerConnection`` and the ``ssl`` wrapping functions. It feeds a
    single CONNECT request to the protocol handler and then asserts both the
    upstream wrap (``ssl.create_default_context().wrap_socket``) and the
    client-side wrap (``ssl.wrap_socket``).
    """
    host, port = uuid.uuid4().hex, 443
    netloc = '{0}:{1}'.format(host, port)

    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
    self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
    self.mock_gen_public_key = mocker.patch(
        'proxy.http.proxy.server.gen_public_key',
    )
    self.mock_server_conn = mocker.patch(
        'proxy.http.proxy.server.TcpServerConnection',
    )
    self.mock_ssl_context = mocker.patch('ssl.create_default_context')
    # NOTE(review): `ssl.wrap_socket` was removed in Python 3.12, so this
    # patch raises AttributeError there — confirm the supported versions.
    self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')

    self.mock_sign_csr.return_value = True
    self.mock_gen_csr.return_value = True
    self.mock_gen_public_key.return_value = True

    # Upstream-side TLS socket returned by the patched SSL context.
    ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection
    self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket)
    plain_connection = mock.MagicMock(spec=socket.socket)

    # The upstream `connection` property switches from the plain socket to
    # the TLS socket once the context's wrap_socket has been called.
    def mock_connection() -> Any:
        if self.mock_ssl_context.return_value.wrap_socket.called:
            return ssl_connection
        return plain_connection

    # Do not mock the original wrap method
    # NOTE(review): this lambda accepts exactly two positional args; if
    # `TcpServerConnection.wrap` is now called with an `as_non_blocking`
    # keyword it will raise TypeError — confirm against the current signature.
    self.mock_server_conn.return_value.wrap.side_effect = \
        lambda x, y: TcpServerConnection.wrap(
            self.mock_server_conn.return_value, x, y,
        )

    type(self.mock_server_conn.return_value).connection = \
        mock.PropertyMock(side_effect=mock_connection)

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = FlagParser.initialize(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
        threaded=True,
    )
    self.plugin = mock.MagicMock()
    self.proxy_plugin = mock.MagicMock()
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.proxy_plugin],
    }
    self._conn = self.mock_fromfd.return_value
    self.protocol_handler = HttpProtocolHandler(
        HttpClientConnection(self._conn, self._addr),
        flags=self.flags,
    )
    self.protocol_handler.initialize()

    # No plugin is instantiated before a request arrives.
    self.plugin.assert_not_called()
    self.proxy_plugin.assert_not_called()

    connect_request = build_http_request(
        httpMethods.CONNECT, bytes_(netloc),
        headers={
            b'Host': bytes_(netloc),
        },
    )
    self._conn.recv.return_value = connect_request

    async def asyncReturnBool(val: bool) -> bool:
        return val

    # Prepare mocked HttpProtocolHandlerPlugin
    # self.plugin.return_value.get_descriptors.return_value = ([], [])
    # self.plugin.return_value.write_to_descriptors.return_value = asyncReturnBool(False)
    # self.plugin.return_value.read_from_descriptors.return_value = asyncReturnBool(False)
    # self.plugin.return_value.on_client_data.side_effect = lambda raw: raw
    # self.plugin.return_value.on_request_complete.return_value = False
    # self.plugin.return_value.on_response_chunk.side_effect = lambda chunk: chunk
    # self.plugin.return_value.on_client_connection_close.return_value = None

    # Prepare mocked HttpProxyBasePlugin
    # NOTE(review): each return_value below is a single coroutine object; a
    # second call would hand back the same, already-awaited coroutine and
    # raise RuntimeError. That is fine only while each is awaited at most once.
    self.proxy_plugin.return_value.write_to_descriptors.return_value = \
        asyncReturnBool(False)
    self.proxy_plugin.return_value.read_from_descriptors.return_value = \
        asyncReturnBool(False)
    self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r
    self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r
    self.proxy_plugin.return_value.resolve_dns.return_value = None, None

    # A single readable event on the client socket for the CONNECT request.
    self.mock_selector.return_value.select.side_effect = [
        [(
            selectors.SelectorKey(
                fileobj=self._conn.fileno(),
                fd=self._conn.fileno(),
                events=selectors.EVENT_READ,
                data=None,
            ),
            selectors.EVENT_READ,
        )],
    ]

    await self.protocol_handler._run_once()

    # Assert correct plugin was initialized
    self.plugin.assert_not_called()
    self.proxy_plugin.assert_called_once()
    self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags)
    # Actual call arg must be `_conn` object
    # but because internally the reference is updated
    # we assert it against `mock_ssl_wrap` which is
    # called during proxy plugin initialization
    # for interception
    self.assertEqual(
        self.proxy_plugin.call_args[0][2].connection,
        self.mock_ssl_wrap.return_value,
    )

    # Assert our mocked plugins invocations
    # self.plugin.return_value.get_descriptors.assert_called()
    # self.plugin.return_value.write_to_descriptors.assert_called_with([])
    # # on_client_data is only called after initial request has completed
    # self.plugin.return_value.on_client_data.assert_not_called()
    # self.plugin.return_value.on_request_complete.assert_called()
    # self.plugin.return_value.read_from_descriptors.assert_called_with([
    #     self._conn.fileno(),
    # ])
    self.proxy_plugin.return_value.before_upstream_connection.assert_called()
    self.proxy_plugin.return_value.handle_client_request.assert_called()

    # Upstream connection was opened to the CONNECT target and made non-blocking.
    self.mock_server_conn.assert_called_with(host, port)
    self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
        False,
    )

    self.mock_ssl_context.assert_called_with(
        ssl.Purpose.SERVER_AUTH, cafile=str(DEFAULT_CA_FILE),
    )
    # self.assertEqual(self.mock_ssl_context.return_value.options,
    #                  ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 |
    #                  ssl.OP_NO_TLSv1_1)
    self.assertEqual(plain_connection.setblocking.call_count, 2)
    self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
        plain_connection, server_hostname=host,
    )
    # One certificate is generated for the intercepted host.
    self.assertEqual(self.mock_sign_csr.call_count, 1)
    self.assertEqual(self.mock_gen_csr.call_count, 1)
    self.assertEqual(self.mock_gen_public_key.call_count, 1)
    self.assertEqual(ssl_connection.setblocking.call_count, 1)
    self.assertEqual(
        self.mock_server_conn.return_value._conn,
        ssl_connection,
    )
    self._conn.send.assert_called_with(
        PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
    )
    # Client side is wrapped server-side with the generated per-host cert.
    assert self.flags.ca_cert_dir is not None
    self.mock_ssl_wrap.assert_called_with(
        self._conn,
        server_side=True,
        keyfile=self.flags.ca_signing_key_file,
        certfile=HttpProxyPlugin.generated_cert_file_path(
            self.flags.ca_cert_dir, host,
        ),
        ssl_version=ssl.PROTOCOL_TLS,
    )
    self.assertEqual(self._conn.setblocking.call_count, 2)
    self.assertEqual(
        self.protocol_handler.work.connection,
        self.mock_ssl_wrap.return_value,
    )

    # Assert connection references for all other plugins is updated
    # self.assertEqual(
    #     self.plugin.return_value.client._conn,
    #     self.mock_ssl_wrap.return_value,
    # )
    self.assertEqual(
        self.proxy_plugin.return_value.client._conn,
        self.mock_ssl_wrap.return_value,
    )
class TestHttpProxyPlugin:
    """Lifecycle tests for a mocked ``HttpProxyBasePlugin`` run by ``HttpProxyPlugin``.

    Tests whose body was only ``pass`` are marked as skipped. Previously they
    were reported as passing although they asserted nothing.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, mocker: MockerFixture) -> None:
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_fromfd = mocker.patch('socket.fromfd')

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(threaded=True)
        self.plugin = mocker.MagicMock()
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [self.plugin],
        }
        self._conn = self.mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

    def test_proxy_plugin_not_initialized_unless_first_request_completes(self) -> None:
        self.plugin.assert_not_called()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_plugin_on_and_before_upstream_connection(self) -> None:
        self.plugin.return_value.write_to_descriptors.return_value = False
        self.plugin.return_value.read_from_descriptors.return_value = False
        self.plugin.return_value.before_upstream_connection.side_effect = lambda r: r
        self.plugin.return_value.handle_client_request.side_effect = lambda r: r
        self.plugin.return_value.resolve_dns.return_value = None, None

        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://upstream.host/not-found.html',
            headers={
                b'Host': b'upstream.host',
            },
        )
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_called_with(
            'upstream.host', DEFAULT_HTTP_PORT,
        )
        self.plugin.return_value.before_upstream_connection.assert_called()
        self.plugin.return_value.handle_client_request.assert_called()

    @pytest.mark.asyncio  # type: ignore[misc]
    async def test_proxy_plugin_before_upstream_connection_can_teardown(self) -> None:
        self.plugin.return_value.write_to_descriptors.return_value = False
        self.plugin.return_value.read_from_descriptors.return_value = False
        # Raising from before_upstream_connection must prevent the upstream
        # connection from ever being created.
        self.plugin.return_value.before_upstream_connection.side_effect = HttpProtocolException()

        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://upstream.host/not-found.html',
            headers={
                b'Host': b'upstream.host',
            },
        )
        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

        await self.protocol_handler._run_once()
        self.mock_server_conn.assert_not_called()
        self.plugin.return_value.before_upstream_connection.assert_called()

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_plugins_can_teardown_from_write_to_descriptors(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_retries_on_ssl_want_write_error(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_broken_pipe_error_on_write_will_teardown(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_plugins_can_teardown_from_read_from_descriptors(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_retries_on_ssl_want_read_error(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_timeout_error_on_read_will_teardown(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_invokes_handle_pipeline_response(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_invokes_on_access_log(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_skips_server_teardown_when_client_closes_and_server_never_initialized(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_invokes_handle_client_data(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_handles_pipeline_response(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_invokes_resolve_dns(self) -> None:
        pass

    @pytest.mark.skip(reason='Not implemented yet')  # type: ignore[misc]
    def test_proxy_plugin_require_both_host_port_to_connect(self) -> None:
        pass
class TestHttpProxyPluginExamplesWithTlsInterception(Assertions):
    """Example proxy plugins exercised end-to-end with TLS interception.

    The fixture is parametrized indirectly: ``request.param`` is a test name
    that ``get_plugin_by_test_name`` maps to the plugin under test. The
    selector mock is primed with four events, consumed in order by successive
    ``_run_once`` calls:
    1. plain client read (CONNECT);
    2. TLS client read;
    3. upstream TLS write;
    4. upstream TLS read.
    """

    @pytest.fixture(autouse=True)  # type: ignore[misc]
    def _setUp(self, request: Any, mocker: MockerFixture) -> None:
        self.mock_fromfd = mocker.patch('socket.fromfd')
        self.mock_selector = mocker.patch('selectors.DefaultSelector')
        self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
        self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
        self.mock_gen_public_key = mocker.patch(
            'proxy.http.proxy.server.gen_public_key',
        )
        self.mock_server_conn = mocker.patch(
            'proxy.http.proxy.server.TcpServerConnection',
        )
        self.mock_ssl_context = mocker.patch('ssl.create_default_context')
        self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')

        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = FlagParser.initialize(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
            threaded=True,
        )
        plugin = get_plugin_by_test_name(request.param)
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [plugin],
        }
        self._conn = mocker.MagicMock(spec=socket.socket)
        self.mock_fromfd.return_value = self._conn
        self.protocol_handler = HttpProtocolHandler(
            HttpClientConnection(self._conn, self._addr),
            flags=self.flags,
        )
        self.protocol_handler.initialize()

        self.server = self.mock_server_conn.return_value

        self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection
        self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = self.client_ssl_connection

        def has_buffer() -> bool:
            return cast(bool, self.server.queue.called)

        def closed() -> bool:
            return not self.server.connect.called

        # The upstream `connection` property switches to the TLS socket once
        # the SSL context's wrap_socket has been invoked.
        def mock_connection() -> Any:
            if self.mock_ssl_context.return_value.wrap_socket.called:
                return self.server_ssl_connection
            return self._conn

        # Do not mock the original wrap method
        self.server.wrap.side_effect = \
            lambda x, y, as_non_blocking: TcpServerConnection.wrap(
                self.server, x, y, as_non_blocking=as_non_blocking,
            )

        self.server.has_buffer.side_effect = has_buffer
        type(self.server).closed = mocker.PropertyMock(side_effect=closed)
        type(self.server).connection = mocker.PropertyMock(
            side_effect=mock_connection,
        )

        self.mock_selector.return_value.select.side_effect = [
            [(
                selectors.SelectorKey(
                    fileobj=self._conn.fileno(),
                    fd=self._conn.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self.client_ssl_connection.fileno(),
                    fd=self.client_ssl_connection.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self.server_ssl_connection.fileno(),
                    fd=self.server_ssl_connection.fileno(),
                    events=selectors.EVENT_WRITE,
                    data=None,
                ),
                selectors.EVENT_WRITE,
            )],
            [(
                selectors.SelectorKey(
                    fileobj=self.server_ssl_connection.fileno(),
                    fd=self.server_ssl_connection.fileno(),
                    events=selectors.EVENT_READ,
                    data=None,
                ),
                selectors.EVENT_READ,
            )],
        ]

        # Connect
        def send(raw: bytes) -> int:
            return len(raw)

        self._conn.send.side_effect = send
        self._conn.recv.return_value = build_http_request(
            httpMethods.CONNECT, b'uni.corn:443', no_ua=True,
        )

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        '_setUp',
        (('test_modify_post_data_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_modify_post_data_plugin(self) -> None:
        await self.protocol_handler._run_once()

        self.assertEqual(self.mock_sign_csr.call_count, 1)
        self.assertEqual(self.mock_gen_csr.call_count, 1)
        self.assertEqual(self.mock_gen_public_key.call_count, 1)

        self.mock_server_conn.assert_called_once_with('uni.corn', 443)
        self.server.connect.assert_called()
        self.assertEqual(
            self.protocol_handler.work.connection,
            self.client_ssl_connection,
        )
        self.assertEqual(self.server.connection, self.server_ssl_connection)
        self._conn.send.assert_called_with(
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )
        self.assertFalse(self.protocol_handler.work.has_buffer())

        #
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'
        self.client_ssl_connection.recv.return_value = build_http_request(
            b'POST', b'/',
            headers={
                b'Host': b'uni.corn',
                b'Content-Length': bytes_(len(original)),
                b'Content-Type': b'application/x-www-form-urlencoded',
            },
            body=original,
            no_ua=True,
        )
        await self.protocol_handler._run_once()
        self.server.queue.assert_called_once()
        # What was queued upstream is the (modified) client *request*, so
        # parse it with the request parser rather than the response parser.
        modified_request = HttpParser.request(
            self.server.queue.call_args_list[0][0][0].tobytes(),
        )
        self.assertEqual(modified_request.body, modified)

    @pytest.mark.asyncio  # type: ignore[misc]
    @pytest.mark.parametrize(
        '_setUp',
        (('test_man_in_the_middle_plugin'), ),
        indirect=True,
    )  # type: ignore[misc]
    async def test_man_in_the_middle_plugin(self) -> None:
        await self.protocol_handler._run_once()

        self.assertEqual(self.mock_sign_csr.call_count, 1)
        self.assertEqual(self.mock_gen_csr.call_count, 1)
        self.assertEqual(self.mock_gen_public_key.call_count, 1)

        self.mock_server_conn.assert_called_once_with('uni.corn', 443)
        self.server.connect.assert_called()
        self.assertEqual(
            self.protocol_handler.work.connection,
            self.client_ssl_connection,
        )
        self.assertEqual(self.server.connection, self.server_ssl_connection)
        self._conn.send.assert_called_with(
            PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
        )
        self.assertFalse(self.protocol_handler.work.has_buffer())
        #
        request = build_http_request(
            b'GET', b'/',
            headers={
                b'Host': b'uni.corn',
            },
            no_ua=True,
        )
        self.client_ssl_connection.recv.return_value = request

        # Client read
        await self.protocol_handler._run_once()
        self.server.queue.assert_called_once_with(request)

        # Server write
        await self.protocol_handler._run_once()
        self.server.flush.assert_called_once()

        # Server read
        self.server.recv.return_value = okResponse(
            content=b'Original Response From Upstream',
        )
        await self.protocol_handler._run_once()
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(self.protocol_handler.work.buffer[0])
        assert response.body
        self.assertEqual(
            gzip.decompress(response.body),
            b'Hello from man in the middle',
        )
async def test_e2e(self, mocker: MockerFixture) -> None:
    """Check proxy-plugin lifecycle callbacks for an intercepted CONNECT + GET.

    A CONNECT arrives on the plain client socket; TLS interception wraps both
    sides. A GET is then read from the client TLS socket. After each
    ``_run_once`` the test asserts exactly which ``HttpProxyBasePlugin``
    callbacks were (and were not) invoked, and with what request.
    """
    host, port = uuid.uuid4().hex, 443
    netloc = '{0}:{1}'.format(host, port)

    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
    self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
    self.mock_gen_public_key = mocker.patch(
        'proxy.http.proxy.server.gen_public_key',
    )
    self.mock_server_conn = mocker.patch(
        'proxy.http.proxy.server.TcpServerConnection',
    )

    self.mock_sign_csr.return_value = True
    self.mock_gen_csr.return_value = True
    self.mock_gen_public_key.return_value = True

    # Used for server side wrapping
    self.mock_ssl_context = mocker.patch('ssl.create_default_context')
    upstream_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_context.return_value.wrap_socket.return_value = upstream_tls_sock

    # Used for client wrapping
    # NOTE(review): `ssl.wrap_socket` was removed in Python 3.12, so this
    # patch raises AttributeError there — confirm the supported versions.
    self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')
    client_tls_sock = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_wrap.return_value = client_tls_sock

    plain_connection = mock.MagicMock(spec=socket.socket)

    # The upstream `connection` property switches from the plain socket to
    # the TLS socket once the context's wrap_socket has been called.
    def mock_connection() -> Any:
        if self.mock_ssl_context.return_value.wrap_socket.called:
            return upstream_tls_sock
        return plain_connection

    # Do not mock the original wrap method
    self.mock_server_conn.return_value.wrap.side_effect = \
        lambda x, y, as_non_blocking: TcpServerConnection.wrap(
            self.mock_server_conn.return_value, x, y, as_non_blocking=as_non_blocking,
        )

    type(self.mock_server_conn.return_value).connection = \
        mock.PropertyMock(side_effect=mock_connection)
    # Report the upstream connection as open.
    type(self.mock_server_conn.return_value).closed = \
        mock.PropertyMock(return_value=False)

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = FlagParser.initialize(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
        threaded=True,
    )
    self.assertTrue(tls_interception_enabled(self.flags))
    # In this test we enable a mock http protocol handler plugin
    # and a mock http proxy plugin.  Internally, http protocol
    # handler will only initialize proxy plugin as we'll never
    # make any other request.
    self.plugin = mock.MagicMock()
    self.proxy_plugin = mock.MagicMock()
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.proxy_plugin],
    }
    self._conn = self.mock_fromfd.return_value
    self.protocol_handler = HttpProtocolHandler(
        HttpClientConnection(self._conn, self._addr),
        flags=self.flags,
    )
    self.protocol_handler.initialize()

    self.plugin.assert_not_called()
    self.proxy_plugin.assert_not_called()

    # Mock a CONNECT request followed by a GET request
    # from client connection
    headers = {
        b'Host': bytes_(netloc),
    }
    connect_request = build_http_request(
        httpMethods.CONNECT, bytes_(netloc),
        headers=headers,
        no_ua=True,
    )
    self._conn.recv.return_value = connect_request
    # The GET is read from the client TLS socket after interception.
    get_request = build_http_request(
        httpMethods.GET, b'/',
        headers=headers,
    )
    client_tls_sock.recv.return_value = get_request

    T = TypeVar('T')  # noqa: N806

    async def asyncReturn(val: T) -> T:
        return val

    # Prepare mocked HttpProxyBasePlugin

    # 1. Mock descriptor mixin methods
    #
    # NOTE: We need multiple async result otherwise
    # we will end up with cannot await on already
    # awaited coroutine.
    # (One coroutine per expected call: two `_run_once` iterations below.)
    self.proxy_plugin.return_value.get_descriptors.side_effect = \
        [asyncReturn(([], [])), asyncReturn(([], []))]
    self.proxy_plugin.return_value.write_to_descriptors.side_effect = \
        [asyncReturn(False), asyncReturn(False)]
    self.proxy_plugin.return_value.read_from_descriptors.side_effect = \
        [asyncReturn(False), asyncReturn(False)]

    # 2. Mock plugin lifecycle methods
    self.proxy_plugin.return_value.resolve_dns.return_value = None, None
    self.proxy_plugin.return_value.before_upstream_connection.side_effect = lambda r: r
    self.proxy_plugin.return_value.handle_client_data.side_effect = lambda r: r
    self.proxy_plugin.return_value.handle_client_request.side_effect = lambda r: r
    self.proxy_plugin.return_value.handle_upstream_chunk.side_effect = lambda r: r
    self.proxy_plugin.return_value.on_upstream_connection_close.return_value = None
    self.proxy_plugin.return_value.on_access_log.side_effect = lambda r: r
    self.proxy_plugin.return_value.do_intercept.return_value = True

    self.mock_selector.return_value.select.side_effect = [
        # Trigger read on plain socket
        [(
            selectors.SelectorKey(
                fileobj=self._conn.fileno(),
                fd=self._conn.fileno(),
                events=selectors.EVENT_READ,
                data=None,
            ),
            selectors.EVENT_READ,
        )],
        # Trigger read on encrypted socket
        [(
            selectors.SelectorKey(
                fileobj=client_tls_sock.fileno(),
                fd=client_tls_sock.fileno(),
                events=selectors.EVENT_READ,
                data=None,
            ),
            selectors.EVENT_READ,
        )],
    ]

    # First iteration: process the CONNECT on the plain socket.
    await self.protocol_handler._run_once()

    # Assert correct plugin was initialized
    self.plugin.assert_not_called()
    self.proxy_plugin.assert_called_once()
    self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags)
    # Actual call arg must be `_conn` object
    # but because internally the reference is updated
    # we assert it against `mock_ssl_wrap` which is
    # called during proxy plugin initialization
    # for interception
    self.assertEqual(
        self.proxy_plugin.call_args[0][2].connection,
        client_tls_sock,
    )

    # Invoked lifecycle callbacks
    self.proxy_plugin.return_value.resolve_dns.assert_called_once_with(
        host, port,
    )
    self.proxy_plugin.return_value.before_upstream_connection.assert_called()
    self.proxy_plugin.return_value.handle_client_request.assert_called_once()
    self.proxy_plugin.return_value.do_intercept.assert_called_once()
    # All the invoked lifecycle callbacks receive the same CONNECT request
    # packet; its URL carries only host:port, with no path remainder.
    callback_request: HttpParser = \
        self.proxy_plugin.return_value.before_upstream_connection.call_args_list[0][0][0]
    callback_request1: HttpParser = \
        self.proxy_plugin.return_value.handle_client_request.call_args_list[0][0][0]
    callback_request2: HttpParser = \
        self.proxy_plugin.return_value.do_intercept.call_args_list[0][0][0]
    self.assertEqual(callback_request.host, bytes_(host))
    self.assertEqual(callback_request.port, 443)
    self.assertEqual(callback_request.header(b'Host'), headers[b'Host'])
    assert callback_request._url
    self.assertEqual(callback_request._url.remainder, None)
    self.assertEqual(callback_request.method, httpMethods.CONNECT)
    self.assertEqual(callback_request.is_https_tunnel, True)
    self.assertEqual(callback_request.build(), callback_request1.build())
    self.assertEqual(callback_request.build(), callback_request2.build())
    # Lifecycle callbacks not invoked
    self.proxy_plugin.return_value.handle_client_data.assert_not_called()
    self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called()
    self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called()
    self.proxy_plugin.return_value.on_access_log.assert_not_called()

    # Upstream connection opened to the CONNECT target and made non-blocking.
    self.mock_server_conn.assert_called_with(host, port)
    self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
        False,
    )

    self.mock_ssl_context.assert_called_with(
        ssl.Purpose.SERVER_AUTH, cafile=str(DEFAULT_CA_FILE),
    )
    self.assertEqual(plain_connection.setblocking.call_count, 2)
    self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
        plain_connection, server_hostname=host,
    )
    # One certificate is generated for the intercepted host.
    self.assertEqual(self.mock_sign_csr.call_count, 1)
    self.assertEqual(self.mock_gen_csr.call_count, 1)
    self.assertEqual(self.mock_gen_public_key.call_count, 1)
    self.assertEqual(upstream_tls_sock.setblocking.call_count, 1)
    self.assertEqual(
        self.mock_server_conn.return_value._conn,
        upstream_tls_sock,
    )
    self._conn.send.assert_called_with(
        PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT,
    )
    # Client side is wrapped server-side with the generated per-host cert.
    assert self.flags.ca_cert_dir is not None
    self.mock_ssl_wrap.assert_called_with(
        self._conn,
        server_side=True,
        keyfile=self.flags.ca_signing_key_file,
        certfile=HttpProxyPlugin.generated_cert_file_path(
            self.flags.ca_cert_dir, host,
        ),
        ssl_version=ssl.PROTOCOL_TLS,
    )
    self.assertEqual(self._conn.setblocking.call_count, 2)
    self.assertEqual(
        self.protocol_handler.work.connection,
        client_tls_sock,
    )

    # Assert connection references for all other plugins is updated
    self.assertEqual(
        self.proxy_plugin.return_value.client._conn,
        client_tls_sock,
    )

    # Now process the GET request
    await self.protocol_handler._run_once()
    self.plugin.assert_not_called()
    self.proxy_plugin.assert_called_once()

    # Lifecycle callbacks still not invoked
    self.proxy_plugin.return_value.handle_client_data.assert_not_called()
    self.proxy_plugin.return_value.handle_upstream_chunk.assert_not_called()
    self.proxy_plugin.return_value.on_upstream_connection_close.assert_not_called()
    self.proxy_plugin.return_value.on_access_log.assert_not_called()
    # Only handle client request lifecycle must be called again
    self.proxy_plugin.return_value.resolve_dns.assert_called_once_with(
        host, port,
    )
    self.proxy_plugin.return_value.before_upstream_connection.assert_called()
    self.assertEqual(
        self.proxy_plugin.return_value.handle_client_request.call_count,
        2,
    )
    self.proxy_plugin.return_value.do_intercept.assert_called_once()

    # The in-tunnel GET has a relative URL: host is None and the parser
    # reports port 80 (presumably its default when no port is given).
    callback_request = \
        self.proxy_plugin.return_value.handle_client_request.call_args_list[1][0][0]
    self.assertEqual(callback_request.host, None)
    self.assertEqual(callback_request.port, 80)
    self.assertEqual(callback_request.header(b'Host'), headers[b'Host'])
    assert callback_request._url
    self.assertEqual(callback_request._url.remainder, b'/')
    self.assertEqual(callback_request.method, httpMethods.GET)
    self.assertEqual(callback_request.is_https_tunnel, False)
def _setUp(self, request: Any, mocker: MockerFixture) -> None:
    """Build a TLS-interception ``HttpProtocolHandler`` wired to mocks.

    Patches ``socket.fromfd``, ``selectors.DefaultSelector``,
    ``ssl.create_default_context``, ``ssl.wrap_socket`` and the
    ``sign_csr`` / ``gen_csr`` / ``gen_public_key`` /
    ``TcpServerConnection`` names looked up in ``proxy.http.proxy.server``.

    The proxy plugin under test is resolved from ``request.param``
    (presumably supplied via indirect parametrization -- confirm at the
    call sites). It is registered as the sole ``HttpProxyBasePlugin``
    behind ``HttpProxyPlugin``.

    Once this returns, the mocked client socket ``recv``s a CONNECT request
    for ``uni.corn:443``. Successive ``select()`` calls then report, in
    order: client read, client-TLS read, upstream-TLS write, upstream-TLS
    read.
    """
    # --- patch everything the handler would reach out to ----------------
    self.mock_fromfd = mocker.patch('socket.fromfd')
    self.mock_selector = mocker.patch('selectors.DefaultSelector')
    self.mock_sign_csr = mocker.patch('proxy.http.proxy.server.sign_csr')
    self.mock_gen_csr = mocker.patch('proxy.http.proxy.server.gen_csr')
    self.mock_gen_public_key = mocker.patch(
        'proxy.http.proxy.server.gen_public_key',
    )
    self.mock_server_conn = mocker.patch(
        'proxy.http.proxy.server.TcpServerConnection',
    )
    self.mock_ssl_context = mocker.patch('ssl.create_default_context')
    self.mock_ssl_wrap = mocker.patch('ssl.wrap_socket')

    # Every certificate helper reports success.
    for cert_step in (
        self.mock_sign_csr,
        self.mock_gen_csr,
        self.mock_gen_public_key,
    ):
        cert_step.return_value = True

    # --- flags and plugin selection -------------------------------------
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = FlagParser.initialize(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
        threaded=True,
    )
    self.plugin = mocker.MagicMock()

    plugin_under_test = get_plugin_by_test_name(request.param)
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [plugin_under_test],
    }

    # --- client side: plain socket handed out by the patched fromfd -----
    client_sock = mocker.MagicMock(spec=socket.socket)
    self._conn = client_sock
    self.mock_fromfd.return_value = client_sock
    self.protocol_handler = HttpProtocolHandler(
        HttpClientConnection(client_sock, self._addr),
        flags=self.flags,
    )
    self.protocol_handler.initialize()

    # --- upstream side: the patched TcpServerConnection instance --------
    self.server = self.mock_server_conn.return_value

    # Upstream TLS socket comes from create_default_context().wrap_socket;
    # the client-facing TLS socket comes from ssl.wrap_socket.
    self.server_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_context.return_value.wrap_socket.return_value = \
        self.server_ssl_connection
    self.client_ssl_connection = mocker.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_wrap.return_value = self.client_ssl_connection

    def has_buffer() -> bool:
        # Upstream has pending data once anything was queued on it.
        return cast(bool, self.server.queue.called)

    def closed() -> bool:
        # Upstream counts as closed until connect() has been invoked.
        return not self.server.connect.called

    def current_connection() -> Any:
        # After the upstream has been TLS-wrapped, expose the TLS socket;
        # before that, this returns the client socket mock.
        wrapped = self.mock_ssl_context.return_value.wrap_socket.called
        return self.server_ssl_connection if wrapped else self._conn

    def real_wrap(x: Any, y: Any, as_non_blocking: Any) -> Any:
        # Run the genuine TcpServerConnection.wrap against the mock server.
        return TcpServerConnection.wrap(
            self.server, x, y, as_non_blocking=as_non_blocking,
        )

    self.server.wrap.side_effect = real_wrap
    self.server.has_buffer.side_effect = has_buffer
    server_type = type(self.server)
    server_type.closed = mocker.PropertyMock(side_effect=closed)
    server_type.connection = mocker.PropertyMock(
        side_effect=current_connection,
    )

    # --- scripted selector readiness, one entry per select() call -------
    def ready(sock: Any, mask: int) -> Any:
        # A single-event select() result for ``sock``.
        key = selectors.SelectorKey(
            fileobj=sock.fileno(),
            fd=sock.fileno(),
            events=mask,
            data=None,
        )
        return [(key, mask)]

    self.mock_selector.return_value.select.side_effect = [
        ready(self._conn, selectors.EVENT_READ),
        ready(self.client_ssl_connection, selectors.EVENT_READ),
        ready(self.server_ssl_connection, selectors.EVENT_WRITE),
        ready(self.server_ssl_connection, selectors.EVENT_READ),
    ]

    # --- client sends a CONNECT; every write is fully accepted ----------
    def accept_all(raw: bytes) -> int:
        return len(raw)

    self._conn.send.side_effect = accept_all
    self._conn.recv.return_value = build_http_request(
        httpMethods.CONNECT, b'uni.corn:443',
        no_ua=True,
    )