def test_authenticated_proxy_http_tunnel(self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """Send a CONNECT request carrying Basic credentials through the handler.

    The handler is created with ``auth_code`` set to the base64 form of
    ``user:pass`` and the client request carries the matching
    ``Proxy-Authorization`` header.  The tunnel response and the data queued
    to the upstream server are checked by the ``assert_*`` helpers of the
    enclosing test case, and the upstream ``flush`` must happen exactly once.
    """
    upstream = mock_server_connection.return_value
    upstream.connect.return_value = True
    upstream.buffer_size.return_value = 0

    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read_read_server_write(
        mock_selector, upstream)

    flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
    plugin_names = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
    flags.plugins = Proxy.load_plugins(plugin_names)

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    assert self.http_server_port is not None
    request_lines = [
        b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
        b'Host: localhost:%d' % self.http_server_port,
        b'User-Agent: proxy.py/%s' % bytes_(__version__),
        b'Proxy-Connection: Keep-Alive',
        b'Proxy-Authorization: Basic dXNlcjpwYXNz',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)

    self.assert_tunnel_response(mock_server_connection, upstream)
    self.protocol_handler.client.flush()
    self.assert_data_queued_to_server(upstream)

    # One more loop iteration should push the queued data upstream.
    self.protocol_handler.run_once()
    upstream.flush.assert_called_once()
def test_default_web_server_returns_404(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """A request for an unknown path queues the default 404 response.

    Only a client-readable event is reported by the mocked selector, so the
    response is inspected in the client buffer.
    """
    self._conn = mock_fromfd.return_value

    readable = selectors.SelectorKey(
        fileobj=self._conn,
        fd=self._conn.fileno,
        events=selectors.EVENT_READ,
        data=None)
    mock_selector.return_value.select.return_value = [
        (readable, selectors.EVENT_READ),
    ]

    flags = Flags()
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    request_lines = [b'GET /hello HTTP/1.1', CRLF]
    self._conn.recv.return_value = CRLF.join(request_lines)
    self.protocol_handler.run_once()

    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.COMPLETE)
    self.assertEqual(
        self.protocol_handler.client.buffer[0],
        HttpWebServerPlugin.DEFAULT_404_RESPONSE)
def test_static_web_server_serves_404(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """With the static server enabled, a missing file is answered with 404.

    The selector is scripted to report the client socket as readable on the
    first loop iteration and writable on the second, so exactly one ``send``
    carrying the default 404 response is expected.
    """
    self._conn = mock_fromfd.return_value
    self._conn.recv.return_value = build_http_request(
        b'GET', b'/not-found.html')

    # One single-event result per expected ``select`` call: read, then write.
    mock_selector.return_value.select.side_effect = [
        [(selectors.SelectorKey(fileobj=self._conn,
                                fd=self._conn.fileno,
                                events=event,
                                data=None), event)]
        for event in (selectors.EVENT_READ, selectors.EVENT_WRITE)
    ]

    flags = Flags(enable_static_server=True)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    for _ in range(2):
        self.protocol_handler.run_once()

    self.assertEqual(mock_selector.return_value.select.call_count, 2)
    self.assertEqual(self._conn.send.call_count, 1)
    sent_payload = self._conn.send.call_args[0][0]
    self.assertEqual(sent_payload, HttpWebServerPlugin.DEFAULT_404_RESPONSE)
def test_static_web_server_serves(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """The static server returns a gzip-encoded ``index.html`` with a 200.

    The file is served out of a private temporary directory that is removed
    when the test finishes, so runs cannot interfere with each other through
    a shared ``<tmp>/static`` folder.

    The response is verified in two parts: the status line and headers must
    match byte-for-byte, and the body must decompress to the original HTML.
    """
    html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''

    self._conn = mock_fromfd.return_value
    self._conn.recv.return_value = build_http_request(
        b'GET', b'/index.html')
    # First loop iteration reads the request, second one writes the response.
    mock_selector.return_value.select.side_effect = [
        [(selectors.SelectorKey(fileobj=self._conn,
                                fd=self._conn.fileno,
                                events=selectors.EVENT_READ,
                                data=None), selectors.EVENT_READ)],
        [(selectors.SelectorKey(fileobj=self._conn,
                                fd=self._conn.fileno,
                                events=selectors.EVENT_WRITE,
                                data=None), selectors.EVENT_WRITE)],
    ]

    with tempfile.TemporaryDirectory() as static_server_dir:
        index_file_path = os.path.join(static_server_dir, 'index.html')
        with open(index_file_path, 'wb') as f:
            f.write(html_file_content)

        flags = Flags(enable_static_server=True,
                      static_server_dir=static_server_dir)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()

        self.protocol_handler.run_once()
        self.protocol_handler.run_once()

    self.assertEqual(mock_selector.return_value.select.call_count, 2)
    self.assertEqual(self._conn.send.call_count, 1)

    encoded_html_file_content = gzip.compress(html_file_content)
    expected = bytes(build_http_response(
        200, reason=b'OK', headers={
            b'Content-Type': b'text/html',
            b'Cache-Control': b'max-age=86400',
            b'Content-Encoding': b'gzip',
            b'Connection': b'close',
            b'Content-Length': bytes_(len(encoded_html_file_content)),
        },
        body=encoded_html_file_content))
    sent = bytes(self._conn.send.call_args[0][0])

    # gzip.compress() stamps the current time into the stream header, so two
    # compressions of the same payload can differ byte-wise when the clock
    # ticks in between.  The compressed length is unaffected, so the headers
    # (including Content-Length) are compared exactly and the body is
    # compared after decompression.
    header_end = b'\r\n\r\n'
    sent_head, separator, sent_body = sent.partition(header_end)
    expected_head = expected.partition(header_end)[0]
    self.assertEqual(separator, header_end)
    self.assertEqual(sent_head, expected_head)
    self.assertEqual(gzip.decompress(sent_body), html_file_content)
def init_and_make_pac_file_request(self, pac_file: str) -> None:
    """Build a handler configured with ``pac_file`` and stage a ``GET /``.

    Replaces ``self.protocol_handler`` with one whose flags carry the given
    PAC file setting and the PAC file plugin, then arranges for the mocked
    client socket to return a bare ``GET / HTTP/1.1`` request.
    """
    pac_flags = Flags(pac_file=pac_file)
    pac_flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin,'
        b'proxy.http.server.HttpWebServerPacFilePlugin')

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=pac_flags)
    self.protocol_handler.initialize()

    request_lines = [b'GET / HTTP/1.1', CRLF]
    self._conn.recv.return_value = CRLF.join(request_lines)
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """Prepare a protocol handler wired to a mocked client socket.

    The handler is built from default ``Flags`` with the HTTP proxy and web
    server plugins loaded, and is initialized before each test runs.
    """
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self._conn = mock_fromfd.return_value
    self.mock_selector = mock_selector

    default_flags = Flags()
    default_flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    self.flags = default_flags

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """Prepare a protocol handler wired to a mocked client socket.

    Flags come from ``Proxy.initialize()`` and the HTTP proxy and web server
    plugins are loaded through ``Proxy.load_plugins``.
    """
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self._conn = mock_fromfd.return_value
    self.mock_selector = mock_selector

    initialized_flags = Proxy.initialize()
    plugin_names = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
    initialized_flags.plugins = Proxy.load_plugins(plugin_names)
    self.flags = initialized_flags

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def init_and_make_pac_file_request(self, pac_file: str) -> None:
    """Build a handler configured with ``pac_file`` and stage a ``GET /``.

    Replaces ``self.protocol_handler`` with one whose flags were initialized
    with the given PAC file setting and that loads the PAC file plugin, then
    makes the mocked client socket return a bare ``GET / HTTP/1.1`` request.
    """
    pac_flags = Proxy.initialize(pac_file=pac_file)
    plugin_names = [
        bytes_(PLUGIN_HTTP_PROXY),
        bytes_(PLUGIN_WEB_SERVER),
        bytes_(PLUGIN_PAC_FILE),
    ]
    pac_flags.plugins = Proxy.load_plugins(plugin_names)

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=pac_flags)
    self.protocol_handler.initialize()

    request_lines = [b'GET / HTTP/1.1', CRLF]
    self._conn.recv.return_value = CRLF.join(request_lines)
def test_on_client_connection_called_on_teardown(
        self, mock_fromfd: mock.Mock) -> None:
    """The plugin's ``on_client_connection_close`` hook fires on teardown.

    A mock stands in for the only protocol handler plugin.  ``run_once`` is
    patched to return ``True`` so that ``run()`` finishes instead of looping.
    """
    plugin_cls = mock.MagicMock()
    flags = Proxy.initialize()
    flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin_cls]}

    self._conn = mock_fromfd.return_value
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    # initialize() must have instantiated the plugin.
    plugin_cls.assert_called()

    with mock.patch.object(
            self.protocol_handler, 'run_once', return_value=True):
        self.protocol_handler.run()

    # NOTE(review): self._conn is a mock here, so ``.closed`` looks like an
    # auto-created (truthy) attribute -- confirm this check can actually fail.
    self.assertTrue(self._conn.closed)
    plugin_cls.return_value.on_client_connection_close.assert_called()
def test_proxy_authentication_failed(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """A proxy request without credentials gets the auth-failed response.

    The handler is configured with an ``auth_code`` while the client request
    carries no ``Proxy-Authorization`` header.
    """
    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read(mock_selector)

    encoded_credentials = base64.b64encode(b'user:pass')
    flags = Flags(auth_code=b'Basic %s' % encoded_credentials)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    request_lines = [
        b'GET http://abhinavsingh.com HTTP/1.1',
        b'Host: abhinavsingh.com',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)
    self.protocol_handler.run_once()

    queued_response = self.protocol_handler.client.buffer[0]
    self.assertEqual(queued_response, ProxyAuthenticationFailed.RESPONSE_PKT)
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """Prepare a handler whose proxy plugin delegates to a mock base plugin.

    ``HttpProxyPlugin`` is registered as the protocol handler plugin and
    ``self.plugin`` (a ``MagicMock``) as the only proxy base plugin, so tests
    can inspect which hooks were invoked.
    """
    self.mock_fromfd = mock_fromfd
    self.mock_selector = mock_selector

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)

    self.flags = Proxy.initialize()
    self.plugin = mock.MagicMock()
    plugin_registry = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.plugin],
    }
    self.flags.plugins = plugin_registry

    self._conn = mock_fromfd.return_value
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def test_authenticated_proxy_http_get(self, mock_server_connection: mock.Mock, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """Feed an authenticated GET to the handler in three separate reads.

    The request line, its terminating CRLF and the headers arrive in
    successive ``recv`` results; the parser state is checked after the first
    two.  The final check is delegated to ``assert_data_queued``.
    """
    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read(mock_selector)

    upstream = mock_server_connection.return_value
    upstream.connect.return_value = True
    upstream.buffer_size.return_value = 0

    flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
    plugin_names = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
    flags.plugins = Proxy.load_plugins(plugin_names)

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    # Read 1: request line without its line terminator.
    assert self.http_server_port is not None
    request_line = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
    self._conn.recv.return_value = request_line
    self.protocol_handler.run_once()
    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.INITIALIZED)

    # Read 2: the CRLF that completes the request line.
    self._conn.recv.return_value = CRLF
    self.protocol_handler.run_once()
    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.LINE_RCVD)

    # Read 3: headers, including the Basic credentials.
    assert self.http_server_port is not None
    header_lines = [
        b'User-Agent: proxy.py/%s' % bytes_(__version__),
        b'Host: localhost:%d' % self.http_server_port,
        b'Accept: */*',
        b'Proxy-Connection: Keep-Alive',
        b'Proxy-Authorization: Basic dXNlcjpwYXNz',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(header_lines)
    self.assert_data_queued(mock_server_connection, upstream)
def test_proxy_authentication_failed(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
    """A proxy request without credentials gets the auth-failed response.

    The proxy auth plugin is loaded alongside the HTTP proxy and web server
    plugins, and the client request carries no ``Proxy-Authorization``
    header.
    """
    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read(mock_selector)

    flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
    plugin_names = [
        bytes_(PLUGIN_HTTP_PROXY),
        bytes_(PLUGIN_WEB_SERVER),
        bytes_(PLUGIN_PROXY_AUTH),
    ]
    flags.plugins = Proxy.load_plugins(plugin_names)

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    request_lines = [
        b'GET http://abhinavsingh.com HTTP/1.1',
        b'Host: abhinavsingh.com',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)
    self.protocol_handler.run_once()

    queued_response = self.protocol_handler.client.buffer[0]
    self.assertEqual(queued_response, ProxyAuthenticationFailed.RESPONSE_PKT)
def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock, mock_sign_csr: mock.Mock, mock_gen_csr: mock.Mock, mock_gen_public_key: mock.Mock, mock_server_conn: mock.Mock, mock_ssl_context: mock.Mock, mock_ssl_wrap: mock.Mock) -> None:
    """Drive the handler through a CONNECT so tests start inside the tunnel.

    All collaborators are mocks.  After wiring them up, a ``CONNECT
    uni.corn:443`` request is fed to the handler and the resulting state is
    asserted: certificate helpers were each called once, the upstream
    connection was created and connected, both sides were swapped to their
    SSL sockets, and the tunnel-established response was sent to the client.
    """
    # Keep every injected mock reachable from the test methods.
    self.mock_fromfd = mock_fromfd
    self.mock_selector = mock_selector
    self.mock_sign_csr = mock_sign_csr
    self.mock_gen_csr = mock_gen_csr
    self.mock_gen_public_key = mock_gen_public_key
    self.mock_server_conn = mock_server_conn
    self.mock_ssl_context = mock_ssl_context
    self.mock_ssl_wrap = mock_ssl_wrap

    # The certificate helpers all report success.
    self.mock_sign_csr.return_value = True
    self.mock_gen_csr.return_value = True
    self.mock_gen_public_key.return_value = True

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = Flags(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
    )
    self.plugin = mock.MagicMock()

    # The plugin under test is looked up from the running test's name.
    selected_plugin = get_plugin_by_test_name(self._testMethodName)
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [selected_plugin],
    }

    self._conn = mock.MagicMock(spec=socket.socket)
    mock_fromfd.return_value = self._conn
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()

    self.server = self.mock_server_conn.return_value

    self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
    ssl_context = self.mock_ssl_context.return_value
    ssl_context.wrap_socket.return_value = self.server_ssl_connection
    self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_wrap.return_value = self.client_ssl_connection

    def server_has_buffer() -> bool:
        # The server has pending data once something was queued to it.
        return cast(bool, self.server.queue.called)

    def server_closed() -> bool:
        # The server counts as closed until connect() has been called.
        return not self.server.connect.called

    def server_connection() -> Any:
        # Before the upstream socket is wrapped, the plain mock socket is
        # reported; afterwards the SSL one.
        wrapped = self.mock_ssl_context.return_value.wrap_socket.called
        if wrapped:
            return self.server_ssl_connection
        return self._conn

    self.server.has_buffer.side_effect = server_has_buffer
    type(self.server).closed = mock.PropertyMock(side_effect=server_closed)
    type(self.server).connection = mock.PropertyMock(
        side_effect=server_connection)

    # One single-event result per successive ``select`` call.
    scripted_events = [
        (self._conn, selectors.EVENT_READ),
        (self.client_ssl_connection, selectors.EVENT_READ),
        (self.server_ssl_connection, selectors.EVENT_WRITE),
        (self.server_ssl_connection, selectors.EVENT_READ),
    ]
    self.mock_selector.return_value.select.side_effect = [
        [(selectors.SelectorKey(fileobj=sock,
                                fd=sock.fileno,
                                events=event,
                                data=None), event)]
        for sock, event in scripted_events
    ]

    # Connect: pretend every byte handed to send() was written.
    self._conn.send.side_effect = lambda raw: len(raw)
    self._conn.recv.return_value = build_http_request(
        httpMethods.CONNECT, b'uni.corn:443')
    self.protocol_handler.run_once()

    for cert_helper in (self.mock_sign_csr,
                        self.mock_gen_csr,
                        self.mock_gen_public_key):
        self.assertEqual(cert_helper.call_count, 1)

    self.mock_server_conn.assert_called_once_with('uni.corn', 443)
    self.server.connect.assert_called()
    self.assertEqual(
        self.protocol_handler.client.connection,
        self.client_ssl_connection)
    self.assertEqual(self.server.connection, self.server_ssl_connection)
    self._conn.send.assert_called_with(
        HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
    self.assertFalse(self.protocol_handler.client.has_buffer())
class TestHttpProxyPluginExamples(unittest.TestCase):
    """Exercise the example proxy plugins over a plain HTTP connection.

    ``setUp`` picks the plugin under test from the name of the running test
    method and installs it as the only ``HttpProxyBasePlugin``.
    """

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """Create an initialized handler around a mocked client socket."""
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = Flags()
        self.plugin = mock.MagicMock()

        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector

        selected_plugin = get_plugin_by_test_name(self._testMethodName)
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [selected_plugin],
        }

        self._conn = mock_fromfd.return_value
        client = TcpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_modify_post_data_plugin(self, mock_server_conn: mock.Mock) -> None:
        """The request queued upstream carries the modified JSON body."""
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'

        client_headers = {
            b'Host': b'httpbin.org',
            b'Content-Type': b'application/x-www-form-urlencoded',
            b'Content-Length': bytes_(len(original)),
        }
        self._conn.recv.return_value = build_http_request(
            b'POST', b'http://httpbin.org/post',
            headers=client_headers, body=original)

        client_readable = selectors.SelectorKey(
            fileobj=self._conn, fd=self._conn.fileno,
            events=selectors.EVENT_READ, data=None)
        self.mock_selector.return_value.select.side_effect = [
            [(client_readable, selectors.EVENT_READ)],
        ]

        self.protocol_handler.run_once()

        mock_server_conn.assert_called_with('httpbin.org', DEFAULT_HTTP_PORT)
        upstream_headers = {
            b'Host': b'httpbin.org',
            b'Content-Length': bytes_(len(modified)),
            b'Content-Type': b'application/json',
            b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
        }
        expected_upstream_request = build_http_request(
            b'POST', b'/post', headers=upstream_headers, body=modified)
        mock_server_conn.return_value.queue.assert_called_with(
            expected_upstream_request)

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_proposed_rest_api_plugin(self, mock_server_conn: mock.Mock) -> None:
        """The plugin answers from its spec without opening an upstream."""
        path = b'/v1/users/'
        api_server = ProposedRestApiPlugin.API_SERVER
        self._conn.recv.return_value = build_http_request(
            b'GET', b'http://%s%s' % (api_server, path),
            headers={b'Host': api_server})

        client_readable = selectors.SelectorKey(
            fileobj=self._conn, fd=self._conn.fileno,
            events=selectors.EVENT_READ, data=None)
        self.mock_selector.return_value.select.side_effect = [
            [(client_readable, selectors.EVENT_READ)],
        ]

        self.protocol_handler.run_once()

        mock_server_conn.assert_not_called()
        expected_body = bytes_(
            json.dumps(ProposedRestApiPlugin.REST_API_SPEC[path]))
        expected_response = build_http_response(
            httpStatusCodes.OK, reason=b'OK',
            headers={b'Content-Type': b'application/json'},
            body=expected_body)
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            expected_response)

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_redirect_to_custom_server_plugin(
            self, mock_server_conn: mock.Mock) -> None:
        """The request is re-targeted at the plugin's upstream server."""
        request = build_http_request(
            b'GET', b'http://example.org/get',
            headers={b'Host': b'example.org'})
        self._conn.recv.return_value = request

        client_readable = selectors.SelectorKey(
            fileobj=self._conn, fd=self._conn.fileno,
            events=selectors.EVENT_READ, data=None)
        self.mock_selector.return_value.select.side_effect = [
            [(client_readable, selectors.EVENT_READ)],
        ]

        self.protocol_handler.run_once()

        upstream = urlparse.urlsplit(
            RedirectToCustomServerPlugin.UPSTREAM_SERVER)
        mock_server_conn.assert_called_with('localhost', 8899)
        expected_upstream_request = build_http_request(
            b'GET', upstream.path,
            headers={
                b'Host': upstream.netloc,
                b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE,
            })
        mock_server_conn.return_value.queue.assert_called_with(
            expected_upstream_request)

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_filter_by_upstream_host_plugin(
            self, mock_server_conn: mock.Mock) -> None:
        """A request to google.com is rejected before any upstream connect."""
        request = build_http_request(
            b'GET', b'http://google.com/',
            headers={b'Host': b'google.com'})
        self._conn.recv.return_value = request

        client_readable = selectors.SelectorKey(
            fileobj=self._conn, fd=self._conn.fileno,
            events=selectors.EVENT_READ, data=None)
        self.mock_selector.return_value.select.side_effect = [
            [(client_readable, selectors.EVENT_READ)],
        ]

        self.protocol_handler.run_once()

        mock_server_conn.assert_not_called()
        expected_response = build_http_response(
            status_code=httpStatusCodes.I_AM_A_TEAPOT,
            reason=b'I\'m a tea pot',
            headers={b'Connection': b'close'},
        )
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            expected_response)

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_man_in_the_middle_plugin(self, mock_server_conn: mock.Mock) -> None:
        """The upstream response body is replaced before reaching the client.

        Three loop iterations are scripted: client read, server write and
        server read.
        """
        request = build_http_request(
            b'GET', b'http://super.secure/',
            headers={b'Host': b'super.secure'})
        self._conn.recv.return_value = request

        server = mock_server_conn.return_value
        server.connect.return_value = True

        # The server has data to flush once something was queued, and counts
        # as closed until connect() has been called.
        server.has_buffer.side_effect = (
            lambda: cast(bool, server.queue.called))
        type(server).closed = mock.PropertyMock(
            side_effect=lambda: not server.connect.called)

        self.mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=server.connection,
                                    fd=server.connection.fileno,
                                    events=selectors.EVENT_WRITE,
                                    data=None), selectors.EVENT_WRITE)],
            [(selectors.SelectorKey(fileobj=server.connection,
                                    fd=server.connection.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
        ]

        # Client read
        self.protocol_handler.run_once()
        mock_server_conn.assert_called_with('super.secure', DEFAULT_HTTP_PORT)
        server.connect.assert_called_once()
        queued_request = build_http_request(
            b'GET', b'/',
            headers={
                b'Host': b'super.secure',
                b'Via': b'1.1 %s' % PROXY_AGENT_HEADER_VALUE
            })
        server.queue.assert_called_once_with(queued_request)

        # Server write
        self.protocol_handler.run_once()
        server.flush.assert_called_once()

        # Server read
        server.recv.return_value = build_http_response(
            httpStatusCodes.OK, reason=b'OK',
            body=b'Original Response From Upstream')
        self.protocol_handler.run_once()

        expected_response = build_http_response(
            httpStatusCodes.OK, reason=b'OK',
            body=b'Hello from man in the middle')
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            expected_response)
class TestHttpProxyPluginExamplesWithTlsInterception(unittest.TestCase):
    """Exercise the example proxy plugins inside an intercepted TLS tunnel.

    ``setUp`` performs the CONNECT handshake against mocks, so each test
    begins with the client and server already switched to SSL sockets.
    """

    @mock.patch('ssl.wrap_socket')
    @mock.patch('ssl.create_default_context')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    @mock.patch('proxy.http.proxy.server.gen_public_key')
    @mock.patch('proxy.http.proxy.server.gen_csr')
    @mock.patch('proxy.http.proxy.server.sign_csr')
    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self,
              mock_fromfd: mock.Mock,
              mock_selector: mock.Mock,
              mock_sign_csr: mock.Mock,
              mock_gen_csr: mock.Mock,
              mock_gen_public_key: mock.Mock,
              mock_server_conn: mock.Mock,
              mock_ssl_context: mock.Mock,
              mock_ssl_wrap: mock.Mock) -> None:
        """Wire up the mocks and drive a ``CONNECT uni.corn:443`` request.

        After the request is handled, the method asserts that the certificate
        helpers were each called once, that the upstream connection was
        created and connected, that both sides now use their SSL sockets and
        that the tunnel-established response was sent to the client.
        """
        # Keep every injected mock reachable from the test methods.
        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector
        self.mock_sign_csr = mock_sign_csr
        self.mock_gen_csr = mock_gen_csr
        self.mock_gen_public_key = mock_gen_public_key
        self.mock_server_conn = mock_server_conn
        self.mock_ssl_context = mock_ssl_context
        self.mock_ssl_wrap = mock_ssl_wrap

        # The certificate helpers all report success.
        self.mock_sign_csr.return_value = True
        self.mock_gen_csr.return_value = True
        self.mock_gen_public_key.return_value = True

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self.flags = Flags(
            ca_cert_file='ca-cert.pem',
            ca_key_file='ca-key.pem',
            ca_signing_key_file='ca-signing-key.pem',
        )
        self.plugin = mock.MagicMock()

        # The plugin under test is looked up from the running test's name.
        selected_plugin = get_plugin_by_test_name(self._testMethodName)
        self.flags.plugins = {
            b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
            b'HttpProxyBasePlugin': [selected_plugin],
        }

        self._conn = mock.MagicMock(spec=socket.socket)
        mock_fromfd.return_value = self._conn
        client = TcpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

        self.server = self.mock_server_conn.return_value

        self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        ssl_context = self.mock_ssl_context.return_value
        ssl_context.wrap_socket.return_value = self.server_ssl_connection
        self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
        self.mock_ssl_wrap.return_value = self.client_ssl_connection

        def server_has_buffer() -> bool:
            # The server has pending data once something was queued to it.
            return cast(bool, self.server.queue.called)

        def server_closed() -> bool:
            # The server counts as closed until connect() has been called.
            return not self.server.connect.called

        def server_connection() -> Any:
            # Before the upstream socket is wrapped, the plain mock socket is
            # reported; afterwards the SSL one.
            wrapped = self.mock_ssl_context.return_value.wrap_socket.called
            if wrapped:
                return self.server_ssl_connection
            return self._conn

        self.server.has_buffer.side_effect = server_has_buffer
        type(self.server).closed = mock.PropertyMock(
            side_effect=server_closed)
        type(self.server).connection = mock.PropertyMock(
            side_effect=server_connection)

        # One single-event result per successive ``select`` call.
        scripted_events = [
            (self._conn, selectors.EVENT_READ),
            (self.client_ssl_connection, selectors.EVENT_READ),
            (self.server_ssl_connection, selectors.EVENT_WRITE),
            (self.server_ssl_connection, selectors.EVENT_READ),
        ]
        self.mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=sock,
                                    fd=sock.fileno,
                                    events=event,
                                    data=None), event)]
            for sock, event in scripted_events
        ]

        # Connect: pretend every byte handed to send() was written.
        self._conn.send.side_effect = lambda raw: len(raw)
        self._conn.recv.return_value = build_http_request(
            httpMethods.CONNECT, b'uni.corn:443')
        self.protocol_handler.run_once()

        for cert_helper in (self.mock_sign_csr,
                            self.mock_gen_csr,
                            self.mock_gen_public_key):
            self.assertEqual(cert_helper.call_count, 1)

        self.mock_server_conn.assert_called_once_with('uni.corn', 443)
        self.server.connect.assert_called()
        self.assertEqual(
            self.protocol_handler.client.connection,
            self.client_ssl_connection)
        self.assertEqual(self.server.connection, self.server_ssl_connection)
        self._conn.send.assert_called_with(
            HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        self.assertFalse(self.protocol_handler.client.has_buffer())

    def test_modify_post_data_plugin(self) -> None:
        """The request queued upstream carries the modified JSON body."""
        original = b'{"key": "value"}'
        modified = b'{"key": "modified"}'

        client_headers = {
            b'Host': b'uni.corn',
            b'Content-Type': b'application/x-www-form-urlencoded',
            b'Content-Length': bytes_(len(original)),
        }
        self.client_ssl_connection.recv.return_value = build_http_request(
            b'POST', b'/', headers=client_headers, body=original)

        self.protocol_handler.run_once()

        upstream_headers = {
            b'Host': b'uni.corn',
            b'Content-Length': bytes_(len(modified)),
            b'Content-Type': b'application/json',
        }
        expected_upstream_request = build_http_request(
            b'POST', b'/', headers=upstream_headers, body=modified)
        self.server.queue.assert_called_with(expected_upstream_request)

    def test_man_in_the_middle_plugin(self) -> None:
        """The upstream response body is replaced before reaching the client."""
        request = build_http_request(
            b'GET', b'/', headers={b'Host': b'uni.corn'})
        self.client_ssl_connection.recv.return_value = request

        # Client read
        self.protocol_handler.run_once()
        self.server.queue.assert_called_once_with(request)

        # Server write
        self.protocol_handler.run_once()
        self.server.flush.assert_called_once()

        # Server read
        self.server.recv.return_value = build_http_response(
            httpStatusCodes.OK, reason=b'OK',
            body=b'Original Response From Upstream')
        self.protocol_handler.run_once()

        expected_response = build_http_response(
            httpStatusCodes.OK, reason=b'OK',
            body=b'Hello from man in the middle')
        self.assertEqual(
            self.protocol_handler.client.buffer[0].tobytes(),
            expected_response)
class TestWebServerPlugin(unittest.TestCase):
    """Tests for the built-in web server, PAC file and static file plugins."""

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """Create an initialized handler around a mocked client socket."""
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = mock_fromfd.return_value
        self.mock_selector = mock_selector
        self.flags = Flags()
        self.flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=self.flags)
        self.protocol_handler.initialize()

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_pac_file_served_from_disk(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """A PAC file given as a path is served with its on-disk content."""
        pac_file = os.path.join(
            os.path.dirname(PROXY_PY_DIR), 'helper', 'proxy.pac')
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)
        self.init_and_make_pac_file_request(pac_file)
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE)
        with open(pac_file, 'rb') as f:
            pac_file_content = f.read()
        # Only a read event is simulated, so the response is checked where
        # test_default_web_server_returns_404 checks it: in the client
        # buffer.  (``send.called_once_with`` is not a Mock assertion and
        # verified nothing.)
        self.assertEqual(
            self.protocol_handler.client.buffer[0],
            build_http_response(200, reason=b'OK', headers={
                b'Content-Type': b'application/x-ns-proxy-autoconfig',
                b'Connection': b'close'
            }, body=pac_file_content))

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_pac_file_served_from_buffer(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """A PAC file given as inline content is served verbatim."""
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)
        pac_file_content = b'function FindProxyForURL(url, host) { return "PROXY localhost:8899; DIRECT"; }'
        self.init_and_make_pac_file_request(text_(pac_file_content))
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE)
        # See test_pac_file_served_from_disk: assert on the queued response
        # with a real assertion instead of the no-op ``called_once_with``.
        self.assertEqual(
            self.protocol_handler.client.buffer[0],
            build_http_response(200, reason=b'OK', headers={
                b'Content-Type': b'application/x-ns-proxy-autoconfig',
                b'Connection': b'close'
            }, body=pac_file_content))

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_default_web_server_returns_404(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """A request for an unknown path queues the default 404 response."""
        self._conn = mock_fromfd.return_value
        mock_selector.return_value.select.return_value = [
            (selectors.SelectorKey(fileobj=self._conn,
                                   fd=self._conn.fileno,
                                   events=selectors.EVENT_READ,
                                   data=None), selectors.EVENT_READ),
        ]
        flags = Flags()
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET /hello HTTP/1.1',
            CRLF,
        ])
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE)
        self.assertEqual(
            self.protocol_handler.client.buffer[0],
            HttpWebServerPlugin.DEFAULT_404_RESPONSE)

    @unittest.skipIf(os.environ.get(
        'GITHUB_ACTIONS',
        False
    ), 'Disabled on GitHub actions because this test is flaky on GitHub infrastructure.'
    )
    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_static_web_server_serves(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """The static server returns a gzip-encoded ``index.html`` with a 200.

        The file is served out of a private temporary directory that is
        removed when the test finishes.  Headers are compared byte-for-byte
        and the body is compared after decompression.
        """
        html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''

        self._conn = mock_fromfd.return_value
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/index.html')
        # First loop iteration reads the request, second writes the response.
        mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_WRITE,
                                    data=None), selectors.EVENT_WRITE)],
        ]

        with tempfile.TemporaryDirectory() as static_server_dir:
            index_file_path = os.path.join(static_server_dir, 'index.html')
            with open(index_file_path, 'wb') as f:
                f.write(html_file_content)

            flags = Flags(enable_static_server=True,
                          static_server_dir=static_server_dir)
            flags.plugins = Flags.load_plugins(
                b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
            )
            self.protocol_handler = HttpProtocolHandler(
                TcpClientConnection(self._conn, self._addr), flags=flags)
            self.protocol_handler.initialize()

            self.protocol_handler.run_once()
            self.protocol_handler.run_once()

        self.assertEqual(mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)

        encoded_html_file_content = gzip.compress(html_file_content)
        expected = bytes(build_http_response(
            200, reason=b'OK', headers={
                b'Content-Type': b'text/html',
                b'Cache-Control': b'max-age=86400',
                b'Content-Encoding': b'gzip',
                b'Connection': b'close',
                b'Content-Length': bytes_(len(encoded_html_file_content)),
            },
            body=encoded_html_file_content))
        sent = bytes(self._conn.send.call_args[0][0])

        # gzip.compress() stamps the current time into the stream header, so
        # two compressions of the same payload can differ byte-wise when the
        # clock ticks in between.  The compressed length is unaffected, so
        # the headers are compared exactly and the body after decompression.
        header_end = b'\r\n\r\n'
        sent_head, separator, sent_body = sent.partition(header_end)
        expected_head = expected.partition(header_end)[0]
        self.assertEqual(separator, header_end)
        self.assertEqual(sent_head, expected_head)
        self.assertEqual(gzip.decompress(sent_body), html_file_content)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_static_web_server_serves_404(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock) -> None:
        """With the static server enabled, a missing file is answered with 404."""
        self._conn = mock_fromfd.return_value
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/not-found.html')
        # First loop iteration reads the request, second writes the response.
        mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_READ,
                                    data=None), selectors.EVENT_READ)],
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=selectors.EVENT_WRITE,
                                    data=None), selectors.EVENT_WRITE)],
        ]
        flags = Flags(enable_static_server=True)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()

        self.protocol_handler.run_once()
        self.protocol_handler.run_once()

        self.assertEqual(mock_selector.return_value.select.call_count, 2)
        self.assertEqual(self._conn.send.call_count, 1)
        self.assertEqual(
            self._conn.send.call_args[0][0],
            HttpWebServerPlugin.DEFAULT_404_RESPONSE)

    @mock.patch('socket.fromfd')
    def test_on_client_connection_called_on_teardown(
            self, mock_fromfd: mock.Mock) -> None:
        """The plugin's ``on_client_connection_close`` hook fires on teardown."""
        flags = Flags()
        plugin = mock.MagicMock()
        flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]}
        self._conn = mock_fromfd.return_value
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()
        plugin.assert_called()
        # run_once() reporting True makes run() finish instead of looping.
        with mock.patch.object(self.protocol_handler, 'run_once') as mock_run_once:
            mock_run_once.return_value = True
            self.protocol_handler.run()
        # NOTE(review): self._conn is a mock here, so ``.closed`` looks like
        # an auto-created (truthy) attribute -- confirm this check can fail.
        self.assertTrue(self._conn.closed)
        plugin.return_value.on_client_connection_close.assert_called()

    def init_and_make_pac_file_request(self, pac_file: str) -> None:
        """Build a handler configured with ``pac_file`` and stage a ``GET /``."""
        flags = Flags(pac_file=pac_file)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin,'
            b'proxy.http.server.HttpWebServerPacFilePlugin')
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr), flags=flags)
        self.protocol_handler.initialize()
        self._conn.recv.return_value = CRLF.join([
            b'GET / HTTP/1.1',
            CRLF,
        ])

    def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:
        """Make every ``select`` call report the client socket as readable."""
        mock_selector.return_value.select.return_value = [
            (selectors.SelectorKey(fileobj=self._conn,
                                   fd=self._conn.fileno,
                                   events=selectors.EVENT_READ,
                                   data=None), selectors.EVENT_READ),
        ]
class TestHttpProxyPlugin(unittest.TestCase):
    # Runs HttpProxyPlugin inside an HttpProtocolHandler with a MagicMock
    # registered as the only HttpProxyBasePlugin, so the hooks the proxy
    # plugin invokes on it can be observed.

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock,
              mock_selector: mock.Mock) -> None:
        # The patched factories are stored because individual tests program
        # them after setUp has returned.
        self.mock_fromfd = mock_fromfd
        self.mock_selector = mock_selector

        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)

        self.flags = Proxy.initialize()
        self.plugin = mock.MagicMock()
        registry = {}
        registry[b'HttpProtocolHandlerPlugin'] = [HttpProxyPlugin]
        registry[b'HttpProxyBasePlugin'] = [self.plugin]
        self.flags.plugins = registry

        self._conn = mock_fromfd.return_value
        client = TcpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

    def _queue_upstream_get(self) -> None:
        # The client socket mock hands back a proxy-style GET for
        # upstream.host on every recv().
        request = build_http_request(
            b'GET',
            b'http://upstream.host/not-found.html',
            headers={b'Host': b'upstream.host'},
        )
        self._conn.recv.return_value = request

    def _client_readable_once(self) -> None:
        # select() reports the client socket readable for exactly one call.
        key = selectors.SelectorKey(
            fileobj=self._conn,
            fd=self._conn.fileno,
            events=selectors.EVENT_READ,
            data=None,
        )
        self.mock_selector.return_value.select.side_effect = [
            [(key, selectors.EVENT_READ)],
        ]

    def test_proxy_plugin_initialized(self) -> None:
        # The mocked base plugin class must have been instantiated in setUp.
        self.plugin.assert_called()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_proxy_plugin_on_and_before_upstream_connection(
            self, mock_server_conn: mock.Mock) -> None:
        # Both hooks pass the request through unchanged.
        hooks = self.plugin.return_value
        hooks.before_upstream_connection.side_effect = lambda r: r
        hooks.handle_client_request.side_effect = lambda r: r

        self._queue_upstream_get()
        self._client_readable_once()

        self.protocol_handler.run_once()

        mock_server_conn.assert_called_with('upstream.host', DEFAULT_HTTP_PORT)
        hooks.before_upstream_connection.assert_called()
        hooks.handle_client_request.assert_called()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_proxy_plugin_before_upstream_connection_can_teardown(
            self, mock_server_conn: mock.Mock) -> None:
        # The hook raises, and no upstream connection object may be created.
        hooks = self.plugin.return_value
        hooks.before_upstream_connection.side_effect = HttpProtocolException()

        self._queue_upstream_get()
        self._client_readable_once()

        self.protocol_handler.run_once()

        hooks.before_upstream_connection.assert_called()
        mock_server_conn.assert_not_called()
class TestHttpProtocolHandler(unittest.TestCase):
    # Drives HttpProtocolHandler (proxy + web server plugins) against a mocked
    # client socket, a mocked selector and, where needed, a mocked upstream
    # TcpServerConnection.

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def setUp(self, mock_fromfd: mock.Mock,
              mock_selector: mock.Mock) -> None:
        self.fileno = 10
        self._addr = ('127.0.0.1', 54382)
        self._conn = mock_fromfd.return_value
        self.http_server_port = 65535

        self.flags = Proxy.initialize()
        plugin_names = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
        self.flags.plugins = Proxy.load_plugins(plugin_names)

        self.mock_selector = mock_selector
        client = TcpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
        self.protocol_handler.initialize()

    @staticmethod
    def _key(fileobj: mock.Mock, events: int) -> selectors.SelectorKey:
        # SelectorKey for a mocked socket; fd is the mock's fileno attribute.
        return selectors.SelectorKey(
            fileobj=fileobj, fd=fileobj.fileno, events=events, data=None)

    def _start_handler(self, flags: mock.Mock) -> None:
        # Replace self.protocol_handler with a fresh, initialized handler
        # bound to the current self._conn and the given flags.
        client = TcpClientConnection(self._conn, self._addr)
        self.protocol_handler = HttpProtocolHandler(client, flags=flags)
        self.protocol_handler.initialize()

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_http_get(self, mock_server_connection: mock.Mock) -> None:
        server = mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0
        self.mock_selector_for_client_read_read_server_write(
            self.mock_selector, server)

        # First read delivers only the request line.
        assert self.http_server_port is not None
        request_line = \
            b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
        self._conn.recv.return_value = request_line + CRLF
        self.protocol_handler.run_once()
        parser_state = self.protocol_handler.request.state
        self.assertEqual(parser_state, httpParserStates.LINE_RCVD)
        self.assertNotEqual(parser_state, httpParserStates.COMPLETE)

        # Second read delivers the headers and the terminating blank line,
        # which completes the HTTP request.
        assert self.http_server_port is not None
        header_lines = [
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            CRLF,
        ]
        self._conn.recv.return_value = CRLF.join(header_lines)
        self.assert_data_queued(mock_server_connection, server)
        self.protocol_handler.run_once()
        server.flush.assert_called_once()

    def assert_tunnel_response(self, mock_server_connection: mock.Mock,
                               server: mock.Mock) -> None:
        # One handler iteration must set up the upstream connection and
        # buffer a "tunnel established" response for the client.
        self.protocol_handler.run_once()
        proxy_plugin = cast(
            HttpProxyPlugin,
            self.protocol_handler.plugins['HttpProxyPlugin'])
        self.assertTrue(proxy_plugin.server is not None)

        buffered = self.protocol_handler.client.buffer[0]
        self.assertEqual(
            buffered, HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.queue.assert_not_called()
        server.closed = False

        # The buffered bytes must parse as a complete 200 response.
        response = HttpParser(httpParserTypes.RESPONSE_PARSER)
        response.parse(buffered.tobytes())
        self.assertEqual(response.state, httpParserStates.COMPLETE)
        assert response.code is not None
        self.assertEqual(int(response.code), 200)

    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_http_tunnel(self, mock_server_connection: mock.Mock) -> None:
        server = mock_server_connection.return_value
        server.connect.return_value = True
        # has_buffer() mirrors whether queue() has been called on the server.
        server.has_buffer.side_effect = \
            lambda: cast(bool, server.queue.called)

        read, write = selectors.EVENT_READ, selectors.EVENT_WRITE
        self.mock_selector.return_value.select.side_effect = [
            [(self._key(self._conn, read), read)],
            [(self._key(self._conn, 0), write)],
            [(self._key(self._conn, read), read)],
            [(self._key(server.connection, 0), write)],
        ]

        assert self.http_server_port is not None
        connect_lines = [
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Proxy-Connection: Keep-Alive',
            CRLF,
        ]
        self._conn.recv.return_value = CRLF.join(connect_lines)
        self.assert_tunnel_response(mock_server_connection, server)

        # Dispatch tunnel established response to client
        self.protocol_handler.run_once()
        self.assert_data_queued_to_server(server)

        self.protocol_handler.run_once()
        self.assertEqual(server.queue.call_count, 1)
        server.flush.assert_called_once()

    def test_proxy_connection_failed(self) -> None:
        self.mock_selector_for_client_read(self.mock_selector)
        request_lines = [
            b'GET http://unknown.domain HTTP/1.1',
            b'Host: unknown.domain',
            CRLF,
        ]
        self._conn.recv.return_value = CRLF.join(request_lines)
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.client.buffer[0],
            ProxyConnectionFailed.RESPONSE_PKT)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    def test_proxy_authentication_failed(
            self, mock_fromfd: mock.Mock,
            mock_selector: mock.Mock) -> None:
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)

        # Auth plugin is loaded and the request carries no credentials.
        flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
        plugin_names = [
            bytes_(PLUGIN_HTTP_PROXY),
            bytes_(PLUGIN_WEB_SERVER),
            bytes_(PLUGIN_PROXY_AUTH),
        ]
        flags.plugins = Proxy.load_plugins(plugin_names)
        self._start_handler(flags)

        request_lines = [
            b'GET http://abhinavsingh.com HTTP/1.1',
            b'Host: abhinavsingh.com',
            CRLF,
        ]
        self._conn.recv.return_value = CRLF.join(request_lines)
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.client.buffer[0],
            ProxyAuthenticationFailed.RESPONSE_PKT)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_authenticated_proxy_http_get(
            self, mock_server_connection: mock.Mock,
            mock_fromfd: mock.Mock,
            mock_selector: mock.Mock) -> None:
        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read(mock_selector)

        server = mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
        plugin_names = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
        flags.plugins = Proxy.load_plugins(plugin_names)
        self._start_handler(flags)

        # Request line without its CRLF: parser stays in its initial state.
        assert self.http_server_port is not None
        self._conn.recv.return_value = \
            b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.INITIALIZED)

        # The CRLF alone completes the request line.
        self._conn.recv.return_value = CRLF
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.LINE_RCVD)

        assert self.http_server_port is not None
        header_lines = [
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Proxy-Connection: Keep-Alive',
            b'Proxy-Authorization: Basic dXNlcjpwYXNz',
            CRLF,
        ]
        self._conn.recv.return_value = CRLF.join(header_lines)
        self.assert_data_queued(mock_server_connection, server)

    @mock.patch('selectors.DefaultSelector')
    @mock.patch('socket.fromfd')
    @mock.patch('proxy.http.proxy.server.TcpServerConnection')
    def test_authenticated_proxy_http_tunnel(
            self, mock_server_connection: mock.Mock,
            mock_fromfd: mock.Mock,
            mock_selector: mock.Mock) -> None:
        server = mock_server_connection.return_value
        server.connect.return_value = True
        server.buffer_size.return_value = 0

        self._conn = mock_fromfd.return_value
        self.mock_selector_for_client_read_read_server_write(
            mock_selector, server)

        flags = Proxy.initialize(auth_code=base64.b64encode(b'user:pass'))
        plugin_names = [bytes_(PLUGIN_HTTP_PROXY), bytes_(PLUGIN_WEB_SERVER)]
        flags.plugins = Proxy.load_plugins(plugin_names)
        self._start_handler(flags)

        assert self.http_server_port is not None
        connect_lines = [
            b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Proxy-Connection: Keep-Alive',
            b'Proxy-Authorization: Basic dXNlcjpwYXNz',
            CRLF,
        ]
        self._conn.recv.return_value = CRLF.join(connect_lines)
        self.assert_tunnel_response(mock_server_connection, server)
        self.protocol_handler.client.flush()
        self.assert_data_queued_to_server(server)

        self.protocol_handler.run_once()
        server.flush.assert_called_once()

    def mock_selector_for_client_read_read_server_write(
            self, mock_selector: mock.Mock, server: mock.Mock) -> None:
        # Three successive select() results: client readable, client readable
        # again (key registered with events=0), then server writable.
        read, write = selectors.EVENT_READ, selectors.EVENT_WRITE
        mock_selector.return_value.select.side_effect = [
            [(self._key(self._conn, read), read)],
            [(self._key(self._conn, 0), read)],
            [(self._key(server.connection, 0), write)],
        ]

    def assert_data_queued(self, mock_server_connection: mock.Mock,
                           server: mock.Mock) -> None:
        # After one iteration the request must be complete and the rewritten
        # GET must have been queued to the upstream server exactly once.
        self.protocol_handler.run_once()
        self.assertEqual(
            self.protocol_handler.request.state,
            httpParserStates.COMPLETE)
        mock_server_connection.assert_called_once()
        server.connect.assert_called_once()
        server.closed = False

        assert self.http_server_port is not None
        expected_lines = [
            b'GET / HTTP/1.1',
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            b'Host: localhost:%d' % self.http_server_port,
            b'Accept: */*',
            b'Via: 1.1 proxy.py v%s' % bytes_(__version__),
            CRLF,
        ]
        expected = CRLF.join(expected_lines)
        server.queue.assert_called_once_with(expected)
        server.buffer_size.return_value = len(expected)

    def assert_data_queued_to_server(self, server: mock.Mock) -> None:
        # The tunnel response must already have been sent to the client; the
        # next client packet must be queued verbatim and not yet flushed.
        assert self.http_server_port is not None

        sent_to_client = self._conn.send.call_args[0][0]
        self.assertEqual(
            sent_to_client,
            HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)

        payload_lines = [
            b'GET / HTTP/1.1',
            b'Host: localhost:%d' % self.http_server_port,
            b'User-Agent: proxy.py/%s' % bytes_(__version__),
            CRLF,
        ]
        payload = CRLF.join(payload_lines)

        self._conn.recv.return_value = payload
        self.protocol_handler.run_once()

        server.queue.assert_called_once_with(payload)
        server.buffer_size.return_value = len(payload)
        server.flush.assert_not_called()

    def mock_selector_for_client_read(self, mock_selector: mock.Mock) -> None:
        # Every select() call reports the client socket as readable.
        read = selectors.EVENT_READ
        mock_selector.return_value.select.return_value = [
            (self._key(self._conn, read), read),
        ]
def test_e2e(self, mock_fromfd: mock.Mock, mock_selector: mock.Mock,
             mock_popen: mock.Mock, mock_server_conn: mock.Mock,
             mock_ssl_context: mock.Mock,
             mock_ssl_wrap: mock.Mock) -> None:
    # Sends one CONNECT request through a handler configured with CA
    # cert/key files and checks the mocked plugin hooks, the upstream socket
    # wrapping and the client socket wrapping.
    upstream_host = uuid.uuid4().hex
    upstream_port = 443
    netloc = '%s:%d' % (upstream_host, upstream_port)

    self.mock_fromfd = mock_fromfd
    self.mock_selector = mock_selector
    self.mock_popen = mock_popen
    self.mock_server_conn = mock_server_conn
    self.mock_ssl_context = mock_ssl_context
    self.mock_ssl_wrap = mock_ssl_wrap

    tls_upstream = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_context.return_value.wrap_socket.return_value = \
        tls_upstream
    self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket)
    raw_upstream = mock.MagicMock(spec=socket.socket)

    def current_upstream() -> Any:
        # The server's `connection` property yields the plain socket until
        # the SSL context's wrap_socket has been called, the TLS one after.
        wrapped = self.mock_ssl_context.return_value.wrap_socket.called
        return tls_upstream if wrapped else raw_upstream

    server_mock = self.mock_server_conn.return_value
    type(server_mock).connection = mock.PropertyMock(
        side_effect=current_upstream)

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = Flags(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
    )
    self.plugin = mock.MagicMock()
    self.proxy_plugin = mock.MagicMock()
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.proxy_plugin],
    }
    self._conn = mock_fromfd.return_value
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()

    # Both mocked plugin classes must have been instantiated with the flags
    # and a client object wrapping our socket as 2nd and 3rd positionals.
    for factory in (self.plugin, self.proxy_plugin):
        factory.assert_called()
        positional = factory.call_args[0]
        self.assertEqual(positional[1], self.flags)
        self.assertEqual(positional[2].connection, self._conn)

    connect_request = build_http_request(
        httpMethods.CONNECT,
        bytes_(netloc),
        headers={b'Host': bytes_(netloc)},
    )
    self._conn.recv.return_value = connect_request

    # Prepare mocked HttpProtocolHandlerPlugin
    handler_hooks = self.plugin.return_value
    handler_hooks.get_descriptors.return_value = ([], [])
    handler_hooks.write_to_descriptors.return_value = False
    handler_hooks.read_from_descriptors.return_value = False
    handler_hooks.on_client_data.side_effect = lambda raw: raw
    handler_hooks.on_request_complete.return_value = False
    handler_hooks.on_response_chunk.side_effect = lambda chunk: chunk
    handler_hooks.on_client_connection_close.return_value = None

    # Prepare mocked HttpProxyBasePlugin
    proxy_hooks = self.proxy_plugin.return_value
    proxy_hooks.before_upstream_connection.side_effect = lambda r: r
    proxy_hooks.handle_client_request.side_effect = lambda r: r

    # A single select() result: the client socket is readable.
    client_key = selectors.SelectorKey(
        fileobj=self._conn,
        fd=self._conn.fileno,
        events=selectors.EVENT_READ,
        data=None,
    )
    self.mock_selector.return_value.select.side_effect = [
        [(client_key, selectors.EVENT_READ)],
    ]

    self.protocol_handler.run_once()

    # Assert our mocked plugins invocations
    handler_hooks.get_descriptors.assert_called()
    handler_hooks.write_to_descriptors.assert_called_with([])
    handler_hooks.on_client_data.assert_called_with(connect_request)
    handler_hooks.on_request_complete.assert_called()
    handler_hooks.read_from_descriptors.assert_called_with([self._conn])
    proxy_hooks.before_upstream_connection.assert_called()
    proxy_hooks.handle_client_request.assert_called()

    # Upstream side: connection created, then wrapped with TLS.
    self.mock_server_conn.assert_called_with(upstream_host, upstream_port)
    server_mock.connection.setblocking.assert_called_with(False)

    self.mock_ssl_context.assert_called_with(ssl.Purpose.SERVER_AUTH)
    # self.assertEqual(self.mock_ssl_context.return_value.options,
    #                  ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 |
    #                  ssl.OP_NO_TLSv1_1)
    self.assertEqual(raw_upstream.setblocking.call_count, 2)
    self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
        raw_upstream, server_hostname=upstream_host)
    # TODO: Assert Popen arguments, piping, success condition
    self.assertEqual(self.mock_popen.call_count, 2)
    self.assertEqual(tls_upstream.setblocking.call_count, 1)
    self.assertEqual(server_mock._conn, tls_upstream)

    # Client side: tunnel response sent, then the socket is wrapped
    # server-side with the signing key and the generated certificate.
    self._conn.send.assert_called_with(
        HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
    assert self.flags.ca_cert_dir is not None
    generated_cert = HttpProxyPlugin.generated_cert_file_path(
        self.flags.ca_cert_dir, upstream_host)
    self.mock_ssl_wrap.assert_called_with(
        self._conn,
        server_side=True,
        keyfile=self.flags.ca_signing_key_file,
        certfile=generated_cert,
    )
    self.assertEqual(self._conn.setblocking.call_count, 2)

    wrapped_client = self.mock_ssl_wrap.return_value
    self.assertEqual(self.protocol_handler.client.connection,
                     wrapped_client)

    # Assert connection references for all other plugins is updated
    self.assertEqual(handler_hooks.client._conn, wrapped_client)
    self.assertEqual(proxy_hooks.client._conn, wrapped_client)