def test_default_web_server_returns_404(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Send ``GET /hello`` with default flags and expect the default 404.

    The selector mock always reports the client socket as readable. After a
    single ``run_once`` the request must be fully parsed and the first
    queued client buffer must equal
    ``HttpWebServerPlugin.DEFAULT_404_RESPONSE``.
    """
    self._conn = mock_fromfd.return_value
    read_key = selectors.SelectorKey(
        fileobj=self._conn,
        fd=self._conn.fileno,
        events=selectors.EVENT_READ,
        data=None)
    mock_selector.return_value.select.return_value = [
        (read_key, selectors.EVENT_READ),
    ]

    flags = Flags()
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    request_lines = [
        b'GET /hello HTTP/1.1',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)
    self.protocol_handler.run_once()

    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.COMPLETE)
    self.assertEqual(
        self.protocol_handler.client.buffer[0],
        HttpWebServerPlugin.DEFAULT_404_RESPONSE)
def test_static_web_server_serves_404(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Request ``/not-found.html`` with the static server enabled.

    The selector mock reports the client socket readable on the first
    ``run_once`` and writable on the second. Exactly one ``send`` is
    expected, carrying ``HttpWebServerPlugin.DEFAULT_404_RESPONSE``.
    """
    self._conn = mock_fromfd.return_value
    self._conn.recv.return_value = build_http_request(
        b'GET', b'/not-found.html')

    # One selector result per run_once(): first readable, then writable.
    mock_selector.return_value.select.side_effect = [
        [(selectors.SelectorKey(fileobj=self._conn,
                                fd=self._conn.fileno,
                                events=event,
                                data=None),
          event)]
        for event in (selectors.EVENT_READ, selectors.EVENT_WRITE)
    ]

    flags = Flags(enable_static_server=True)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    self.protocol_handler.run_once()
    self.protocol_handler.run_once()

    self.assertEqual(mock_selector.return_value.select.call_count, 2)
    self.assertEqual(self._conn.send.call_count, 1)
    sent_packet = self._conn.send.call_args[0][0]
    self.assertEqual(sent_packet, HttpWebServerPlugin.DEFAULT_404_RESPONSE)
def test_authenticated_proxy_http_tunnel(
        self,
        mock_server_connection: mock.Mock,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Open a CONNECT tunnel through a proxy configured with basic auth.

    The handler is configured with ``auth_code`` derived from
    ``user:pass`` and the request carries the matching
    ``Proxy-Authorization`` header. The tunnel response, the data queued to
    the server, and a single server ``flush`` are then verified.
    """
    server = mock_server_connection.return_value
    server.connect.return_value = True
    server.buffer_size.return_value = 0

    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read_read_server_write(
        mock_selector, server)

    credentials = base64.b64encode(b'user:pass')
    flags = Flags(auth_code=b'Basic %s' % credentials)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    assert self.http_server_port is not None
    connect_request = [
        b'CONNECT localhost:%d HTTP/1.1' % self.http_server_port,
        b'Host: localhost:%d' % self.http_server_port,
        b'User-Agent: proxy.py/%s' % bytes_(__version__),
        b'Proxy-Connection: Keep-Alive',
        b'Proxy-Authorization: Basic dXNlcjpwYXNz',
        CRLF
    ]
    self._conn.recv.return_value = CRLF.join(connect_request)

    self.assert_tunnel_response(mock_server_connection, server)
    self.protocol_handler.client.flush()
    self.assert_data_queued_to_server(server)

    self.protocol_handler.run_once()
    server.flush.assert_called_once()
def test_static_web_server_serves(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Serve ``/index.html`` from a static directory and verify the response.

    An ``index.html`` fixture is written into a private temporary directory
    that is removed when the test finishes, so nothing is left behind in
    the system temp directory and runs cannot collide on a shared path.
    The selector mock reports the client socket readable on the first
    ``run_once`` and writable on the second. Exactly one ``send`` is
    expected, carrying a gzip-encoded 200 response built from the fixture.
    """
    html_file_content = b'''<html><head></head><body><h1>Proxy.py Testing</h1></body></html>'''

    # The directory only needs to exist while the handler processes the
    # request; the assertions below inspect mocks only.
    with tempfile.TemporaryDirectory() as static_server_dir:
        index_file_path = os.path.join(static_server_dir, 'index.html')
        with open(index_file_path, 'wb') as f:
            f.write(html_file_content)

        self._conn = mock_fromfd.return_value
        self._conn.recv.return_value = build_http_request(
            b'GET', b'/index.html')

        # One selector result per run_once(): first readable, then writable.
        mock_selector.return_value.select.side_effect = [
            [(selectors.SelectorKey(fileobj=self._conn,
                                    fd=self._conn.fileno,
                                    events=event,
                                    data=None),
              event)]
            for event in (selectors.EVENT_READ, selectors.EVENT_WRITE)
        ]

        flags = Flags(
            enable_static_server=True,
            static_server_dir=static_server_dir)
        flags.plugins = Flags.load_plugins(
            b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
        )
        self.protocol_handler = HttpProtocolHandler(
            TcpClientConnection(self._conn, self._addr),
            flags=flags)
        self.protocol_handler.initialize()

        self.protocol_handler.run_once()
        self.protocol_handler.run_once()

    self.assertEqual(mock_selector.return_value.select.call_count, 2)
    self.assertEqual(self._conn.send.call_count, 1)

    # NOTE(review): gzip output embeds a timestamp in its header by default;
    # if the plugin compresses in a different clock second than this call
    # the byte-for-byte comparison below could differ -- confirm.
    encoded_html_file_content = gzip.compress(html_file_content)
    expected_response = build_http_response(
        200, reason=b'OK', headers={
            b'Content-Type': b'text/html',
            b'Cache-Control': b'max-age=86400',
            b'Content-Encoding': b'gzip',
            b'Connection': b'close',
            b'Content-Length': bytes_(len(encoded_html_file_content)),
        },
        body=encoded_html_file_content)
    self.assertEqual(self._conn.send.call_args[0][0], expected_response)
def setUp(self) -> None:
    """Create an ``Acceptor`` fed by a fresh pipe and a mocked work class.

    The second end of ``self.pipe`` is handed to the acceptor as its work
    queue; the pipe, flags and mocked work class are kept on the instance.
    """
    self.acceptor_id = 1
    self.mock_protocol_handler = mock.MagicMock()
    self.pipe = multiprocessing.Pipe()
    self.flags = Flags()
    self.acceptor = Acceptor(
        idd=self.acceptor_id,
        work_queue=self.pipe[1],
        flags=self.flags,
        lock=multiprocessing.Lock(),
        work_klass=self.mock_protocol_handler,
    )
def main() -> None:
    """Run an ``AcceptorPool`` with ``EchoServerHandler`` as its work class.

    The pool is configured for port 12345 with one worker. After setup the
    function sleeps in a loop and never returns normally; ``shutdown`` is
    always called on the way out.
    """
    # This example requires `threadless=True`
    flags = Flags(port=12345, num_workers=1, threadless=True)
    pool = AcceptorPool(flags=flags, work_klass=EchoServerHandler)
    try:
        pool.setup()
        # Keep the calling process alive while the pool does the work.
        while True:
            time.sleep(1)
    finally:
        pool.shutdown()
def setUp(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Create an initialized ``HttpProtocolHandler`` over a mocked socket.

    The handler is built with default flags plus the proxy and web server
    plugins. The mocked connection, the selector mock, the flags and the
    handler are stored on the instance.
    """
    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self._conn = mock_fromfd.return_value
    self.mock_selector = mock_selector

    self.flags = Flags()
    self.flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def init_and_make_pac_file_request(self, pac_file: str) -> None:
    """Initialize a handler with the PAC file plugin and stage ``GET /``.

    Args:
        pac_file: Value passed to ``Flags(pac_file=...)``.

    The request is only staged as the mocked socket's ``recv`` return
    value; this helper does not call ``run_once``.
    """
    flags = Flags(pac_file=pac_file)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin,'
        b'proxy.http.server.HttpWebServerPacFilePlugin')

    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    request_lines = [
        b'GET / HTTP/1.1',
        CRLF,
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)
def main() -> None:
    """Run an ``AcceptorPool`` with ``EchoSSLServerHandler`` as its work class.

    The pool is configured for port 12345 with one worker and the
    ``https-key.pem`` / ``https-signed-cert.pem`` key and certificate
    files. After setup the function sleeps in a loop and never returns
    normally; ``shutdown`` is always called on the way out.
    """
    # This example requires `threadless=True`
    flags = Flags(
        port=12345,
        num_workers=1,
        threadless=True,
        keyfile='https-key.pem',
        certfile='https-signed-cert.pem')
    pool = AcceptorPool(flags=flags, work_klass=EchoSSLServerHandler)
    try:
        pool.setup()
        # Keep the calling process alive while the pool does the work.
        while True:
            time.sleep(1)
    finally:
        pool.shutdown()
def test_on_client_connection_called_on_teardown(
        self,
        mock_fromfd: mock.Mock) -> None:
    """Check the plugin's ``on_client_connection_close`` hook runs.

    A mocked plugin class is registered under
    ``b'HttpProtocolHandlerPlugin'``. ``run_once`` is patched to return
    ``True`` before ``run()`` is called, after which the plugin instance's
    ``on_client_connection_close`` must have been invoked.
    """
    plugin = mock.MagicMock()
    flags = Flags()
    flags.plugins = {b'HttpProtocolHandlerPlugin': [plugin]}

    self._conn = mock_fromfd.return_value
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()
    # initialize() is expected to have instantiated the plugin class.
    plugin.assert_called()

    with mock.patch.object(self.protocol_handler, 'run_once') as mock_run_once:
        mock_run_once.return_value = True
        self.protocol_handler.run()

    # NOTE(review): `self._conn` is a mock, so `.closed` is an auto-created
    # (truthy) attribute unless the handler assigns it -- confirm this
    # assertion checks what is intended.
    self.assertTrue(self._conn.closed)
    plugin.return_value.on_client_connection_close.assert_called()
def test_proxy_authentication_failed(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Send a proxy request without credentials to an auth-enabled proxy.

    The handler is configured with an ``auth_code`` while the request has
    no ``Proxy-Authorization`` header. The first queued client buffer must
    equal ``ProxyAuthenticationFailed.RESPONSE_PKT``.
    """
    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read(mock_selector)

    credentials = base64.b64encode(b'user:pass')
    flags = Flags(auth_code=b'Basic %s' % credentials)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    request_lines = [
        b'GET http://abhinavsingh.com HTTP/1.1',
        b'Host: abhinavsingh.com',
        CRLF
    ]
    self._conn.recv.return_value = CRLF.join(request_lines)
    self.protocol_handler.run_once()

    self.assertEqual(
        self.protocol_handler.client.buffer[0],
        ProxyAuthenticationFailed.RESPONSE_PKT)
def setUp(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Create an initialized handler with a mocked ``HttpProxyBasePlugin``.

    ``HttpProxyPlugin`` is registered as the protocol handler plugin and
    ``self.plugin`` (a ``MagicMock``) as the proxy base plugin. The mocks,
    flags, connection and handler are stored on the instance.
    """
    self.mock_fromfd = mock_fromfd
    self.mock_selector = mock_selector

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)

    self.flags = Flags()
    self.plugin = mock.MagicMock()
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.plugin]
    }

    self._conn = mock_fromfd.return_value
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()
def test_authenticated_proxy_http_get(
        self,
        mock_server_connection: mock.Mock,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock) -> None:
    """Feed an authenticated proxied GET to the handler in three pieces.

    The request line without its terminator leaves the parser in
    ``INITIALIZED``, the following ``CRLF`` moves it to ``LINE_RCVD``, and
    the headers (including ``Proxy-Authorization``) are then staged before
    delegating to ``assert_data_queued``.
    """
    self._conn = mock_fromfd.return_value
    self.mock_selector_for_client_read(mock_selector)

    server = mock_server_connection.return_value
    server.connect.return_value = True
    server.buffer_size.return_value = 0

    credentials = base64.b64encode(b'user:pass')
    flags = Flags(auth_code=b'Basic %s' % credentials)
    flags.plugins = Flags.load_plugins(
        b'proxy.http.proxy.HttpProxyPlugin,proxy.http.server.HttpWebServerPlugin'
    )
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=flags)
    self.protocol_handler.initialize()

    # Piece 1: request line only, no line terminator yet.
    assert self.http_server_port is not None
    self._conn.recv.return_value = b'GET http://localhost:%d HTTP/1.1' % self.http_server_port
    self.protocol_handler.run_once()
    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.INITIALIZED)

    # Piece 2: the terminator completing the request line.
    self._conn.recv.return_value = CRLF
    self.protocol_handler.run_once()
    self.assertEqual(
        self.protocol_handler.request.state,
        httpParserStates.LINE_RCVD)

    # Piece 3: headers, staged for assert_data_queued to consume.
    assert self.http_server_port is not None
    header_lines = [
        b'User-Agent: proxy.py/%s' % bytes_(__version__),
        b'Host: localhost:%d' % self.http_server_port,
        b'Accept: */*',
        b'Proxy-Connection: Keep-Alive',
        b'Proxy-Authorization: Basic dXNlcjpwYXNz',
        CRLF
    ]
    self._conn.recv.return_value = CRLF.join(header_lines)
    self.assert_data_queued(mock_server_connection, server)
def test_setup_and_shutdown(
        self,
        mock_acceptor: mock.Mock,
        mock_socket: mock.Mock,
        mock_pipe: mock.Mock,
        mock_send_handle: mock.Mock) -> None:
    """Verify ``AcceptorPool.setup`` and ``shutdown`` with everything mocked.

    After ``setup`` the listening socket must have been created, configured,
    bound, put in listening mode and closed, and one pipe and one acceptor
    must have been created per worker, each acceptor started but not joined.
    After ``shutdown`` every acceptor must have been joined.
    """
    acceptor1 = mock.MagicMock()
    acceptor2 = mock.MagicMock()
    mock_acceptor.side_effect = [acceptor1, acceptor2]

    num_workers = 2
    sock = mock_socket.return_value
    work_klass = mock.MagicMock()
    flags = Flags(num_workers=num_workers)

    pool = AcceptorPool(flags=flags, work_klass=work_klass)
    pool.setup()
    mock_send_handle.assert_called()

    # Setting up the pool must not instantiate the work class.
    work_klass.assert_not_called()
    mock_socket.assert_called_with(
        socket.AF_INET6 if pool.flags.hostname.version == 6 else socket.AF_INET,
        socket.SOCK_STREAM)
    sock.setsockopt.assert_called_with(
        socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    sock.bind.assert_called_with(
        (str(pool.flags.hostname), pool.flags.port))
    sock.listen.assert_called_with(pool.flags.backlog)
    sock.setblocking.assert_called_with(False)

    # assertEqual, not assertTrue: assertTrue(x, y) treats `y` as the
    # failure message and passes for any truthy `x`, so it never compared
    # the call counts with the number of workers.
    self.assertEqual(mock_pipe.call_count, num_workers)
    self.assertEqual(mock_acceptor.call_count, num_workers)

    for acceptor in (acceptor1, acceptor2):
        acceptor.start.assert_called()
    for acceptor in (acceptor1, acceptor2):
        acceptor.join.assert_not_called()
    sock.close.assert_called()

    pool.shutdown()
    for acceptor in (acceptor1, acceptor2):
        acceptor.join.assert_called()
def setUp(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock,
        mock_sign_csr: mock.Mock,
        mock_gen_csr: mock.Mock,
        mock_gen_public_key: mock.Mock,
        mock_server_conn: mock.Mock,
        mock_ssl_context: mock.Mock,
        mock_ssl_wrap: mock.Mock) -> None:
    """Drive a handler through a mocked CONNECT with CA flags configured.

    All injected mocks are stored on the instance. The plugin under test is
    looked up from the running test's method name and registered as the
    ``HttpProxyBasePlugin``. A ``CONNECT uni.corn:443`` request is then
    processed with a single ``run_once`` and the resulting certificate
    generation calls, upstream connection and swapped client/server
    connections are asserted, so each test starts from an established
    tunnel.
    """
    self.mock_fromfd = mock_fromfd
    self.mock_selector = mock_selector
    self.mock_sign_csr = mock_sign_csr
    self.mock_gen_csr = mock_gen_csr
    self.mock_gen_public_key = mock_gen_public_key
    self.mock_server_conn = mock_server_conn
    self.mock_ssl_context = mock_ssl_context
    self.mock_ssl_wrap = mock_ssl_wrap

    # Make every certificate generation step report success.
    self.mock_sign_csr.return_value = True
    self.mock_gen_csr.return_value = True
    self.mock_gen_public_key.return_value = True

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = Flags(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
    )
    self.plugin = mock.MagicMock()

    plugin_under_test = get_plugin_by_test_name(self._testMethodName)
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [HttpProxyPlugin],
        b'HttpProxyBasePlugin': [plugin_under_test],
    }

    self._conn = mock.MagicMock(spec=socket.socket)
    mock_fromfd.return_value = self._conn
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()

    self.server = self.mock_server_conn.return_value

    self.server_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_context.return_value.wrap_socket.return_value = self.server_ssl_connection
    self.client_ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_wrap.return_value = self.client_ssl_connection

    # The mocked server connection derives its state from which of its
    # methods have been called so far.
    def server_has_buffer() -> bool:
        return cast(bool, self.server.queue.called)

    def server_closed() -> bool:
        return not self.server.connect.called

    def server_connection() -> Any:
        wrap_socket = self.mock_ssl_context.return_value.wrap_socket
        return self.server_ssl_connection if wrap_socket.called else self._conn

    self.server.has_buffer.side_effect = server_has_buffer
    type(self.server).closed = mock.PropertyMock(side_effect=server_closed)
    type(self.server).connection = mock.PropertyMock(
        side_effect=server_connection)

    # One selector result per run_once(), in the order they are consumed.
    ready_sequence = (
        (self._conn, selectors.EVENT_READ),
        (self.client_ssl_connection, selectors.EVENT_READ),
        (self.server_ssl_connection, selectors.EVENT_WRITE),
        (self.server_ssl_connection, selectors.EVENT_READ),
    )
    self.mock_selector.return_value.select.side_effect = [
        [(selectors.SelectorKey(fileobj=sock,
                                fd=sock.fileno,
                                events=event,
                                data=None),
          event)]
        for sock, event in ready_sequence
    ]

    # Connect
    def send(raw: bytes) -> int:
        return len(raw)

    self._conn.send.side_effect = send
    self._conn.recv.return_value = build_http_request(
        httpMethods.CONNECT, b'uni.corn:443')
    self.protocol_handler.run_once()

    self.assertEqual(self.mock_sign_csr.call_count, 1)
    self.assertEqual(self.mock_gen_csr.call_count, 1)
    self.assertEqual(self.mock_gen_public_key.call_count, 1)

    self.mock_server_conn.assert_called_once_with('uni.corn', 443)
    self.server.connect.assert_called()
    self.assertEqual(
        self.protocol_handler.client.connection,
        self.client_ssl_connection)
    self.assertEqual(self.server.connection, self.server_ssl_connection)
    self._conn.send.assert_called_with(
        HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)
    self.assertFalse(self.protocol_handler.client.has_buffer())
def test_e2e(
        self,
        mock_fromfd: mock.Mock,
        mock_selector: mock.Mock,
        mock_popen: mock.Mock,
        mock_server_conn: mock.Mock,
        mock_ssl_context: mock.Mock,
        mock_ssl_wrap: mock.Mock) -> None:
    """Run a CONNECT request through the handler with CA flags configured.

    A mocked ``HttpProtocolHandlerPlugin`` and a mocked
    ``HttpProxyBasePlugin`` are registered alongside ``HttpProxyPlugin``.
    After one ``run_once`` the test checks the plugin hook invocations, the
    upstream connection being wrapped with TLS, two ``Popen`` calls, the
    tunnel-established packet, and that the client connection (including
    the references held by the plugins) is the TLS-wrapped socket.
    """
    host, port = uuid.uuid4().hex, 443
    netloc = '{0}:{1}'.format(host, port)

    self.mock_fromfd = mock_fromfd
    self.mock_selector = mock_selector
    self.mock_popen = mock_popen
    self.mock_server_conn = mock_server_conn
    self.mock_ssl_context = mock_ssl_context
    self.mock_ssl_wrap = mock_ssl_wrap

    ssl_connection = mock.MagicMock(spec=ssl.SSLSocket)
    self.mock_ssl_context.return_value.wrap_socket.return_value = ssl_connection
    self.mock_ssl_wrap.return_value = mock.MagicMock(spec=ssl.SSLSocket)
    plain_connection = mock.MagicMock(spec=socket.socket)

    # The upstream `connection` property yields the plain socket until
    # wrap_socket has been called, and the TLS socket afterwards.
    def upstream_connection() -> Any:
        wrap_socket = self.mock_ssl_context.return_value.wrap_socket
        return ssl_connection if wrap_socket.called else plain_connection

    type(self.mock_server_conn.return_value).connection = mock.PropertyMock(
        side_effect=upstream_connection)

    self.fileno = 10
    self._addr = ('127.0.0.1', 54382)
    self.flags = Flags(
        ca_cert_file='ca-cert.pem',
        ca_key_file='ca-key.pem',
        ca_signing_key_file='ca-signing-key.pem',
    )
    self.plugin = mock.MagicMock()
    self.proxy_plugin = mock.MagicMock()
    self.flags.plugins = {
        b'HttpProtocolHandlerPlugin': [self.plugin, HttpProxyPlugin],
        b'HttpProxyBasePlugin': [self.proxy_plugin],
    }

    self._conn = mock_fromfd.return_value
    client = TcpClientConnection(self._conn, self._addr)
    self.protocol_handler = HttpProtocolHandler(client, flags=self.flags)
    self.protocol_handler.initialize()

    # Both plugin classes must have been instantiated with the flags and
    # the client connection.
    self.plugin.assert_called()
    self.assertEqual(self.plugin.call_args[0][1], self.flags)
    self.assertEqual(self.plugin.call_args[0][2].connection, self._conn)
    self.proxy_plugin.assert_called()
    self.assertEqual(self.proxy_plugin.call_args[0][1], self.flags)
    self.assertEqual(
        self.proxy_plugin.call_args[0][2].connection,
        self._conn)

    connect_request = build_http_request(
        httpMethods.CONNECT, bytes_(netloc),
        headers={
            b'Host': bytes_(netloc),
        })
    self._conn.recv.return_value = connect_request

    # Prepare mocked HttpProtocolHandlerPlugin
    handler_plugin = self.plugin.return_value
    handler_plugin.get_descriptors.return_value = ([], [])
    handler_plugin.write_to_descriptors.return_value = False
    handler_plugin.read_from_descriptors.return_value = False
    handler_plugin.on_client_data.side_effect = lambda raw: raw
    handler_plugin.on_request_complete.return_value = False
    handler_plugin.on_response_chunk.side_effect = lambda chunk: chunk
    handler_plugin.on_client_connection_close.return_value = None

    # Prepare mocked HttpProxyBasePlugin
    base_plugin = self.proxy_plugin.return_value
    base_plugin.before_upstream_connection.side_effect = lambda r: r
    base_plugin.handle_client_request.side_effect = lambda r: r

    client_readable = selectors.SelectorKey(
        fileobj=self._conn,
        fd=self._conn.fileno,
        events=selectors.EVENT_READ,
        data=None)
    self.mock_selector.return_value.select.side_effect = [
        [(client_readable, selectors.EVENT_READ)],
    ]

    self.protocol_handler.run_once()

    # Assert our mocked plugins invocations
    handler_plugin.get_descriptors.assert_called()
    handler_plugin.write_to_descriptors.assert_called_with([])
    handler_plugin.on_client_data.assert_called_with(connect_request)
    handler_plugin.on_request_complete.assert_called()
    handler_plugin.read_from_descriptors.assert_called_with([self._conn])
    base_plugin.before_upstream_connection.assert_called()
    base_plugin.handle_client_request.assert_called()

    self.mock_server_conn.assert_called_with(host, port)
    self.mock_server_conn.return_value.connection.setblocking.assert_called_with(
        False)

    self.mock_ssl_context.assert_called_with(ssl.Purpose.SERVER_AUTH)
    # Pending assertion, kept disabled as in the original:
    # self.assertEqual(self.mock_ssl_context.return_value.options,
    #                  ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3 | ssl.OP_NO_TLSv1 |
    #                  ssl.OP_NO_TLSv1_1)
    self.assertEqual(plain_connection.setblocking.call_count, 2)
    self.mock_ssl_context.return_value.wrap_socket.assert_called_with(
        plain_connection, server_hostname=host)

    # TODO: Assert Popen arguments, piping, success condition
    self.assertEqual(self.mock_popen.call_count, 2)
    self.assertEqual(ssl_connection.setblocking.call_count, 1)
    self.assertEqual(
        self.mock_server_conn.return_value._conn,
        ssl_connection)
    self._conn.send.assert_called_with(
        HttpProxyPlugin.PROXY_TUNNEL_ESTABLISHED_RESPONSE_PKT)

    assert self.flags.ca_cert_dir is not None
    expected_cert_file = HttpProxyPlugin.generated_cert_file_path(
        self.flags.ca_cert_dir, host)
    self.mock_ssl_wrap.assert_called_with(
        self._conn,
        server_side=True,
        keyfile=self.flags.ca_signing_key_file,
        certfile=expected_cert_file)
    self.assertEqual(self._conn.setblocking.call_count, 2)
    self.assertEqual(
        self.protocol_handler.client.connection,
        self.mock_ssl_wrap.return_value)

    # Assert connection references for all other plugins is updated
    self.assertEqual(
        handler_plugin.client._conn,
        self.mock_ssl_wrap.return_value)
    self.assertEqual(
        base_plugin.client._conn,
        self.mock_ssl_wrap.return_value)