예제 #1
0
def _do_start_ssl_proxy(port: int, target: PortOrUrl, target_ssl=False):
    """
    Start a TLS-terminating TCP tunnel on ``port`` that forwards traffic to ``target``.

    Blocks the calling thread: the event loop runs until it is stopped or a
    KeyboardInterrupt is received. Afterwards the server and the event loop are
    always torn down, even if serving failed with another exception.

    :param port: port the proxy server binds to (on 0.0.0.0)
    :param target: forwarding target; a value without ":" is treated as a port on
        127.0.0.1, otherwise it is used as ``host:port``
    :param target_ssl: if True, connect to the target over SSL/TLS ("ssl+tunnel"),
        otherwise over plain TCP ("tunnel")
    """
    import pproxy

    from localstack.services.generic_proxy import GenericProxy

    if ":" not in str(target):
        target = "127.0.0.1:%s" % target
    LOG.debug("Starting SSL proxy server %s -> %s", port, target)

    # create server and remote connection
    server = pproxy.Server("secure+tunnel://0.0.0.0:%s" % port)
    target_proto = "ssl+tunnel" if target_ssl else "tunnel"
    remote = pproxy.Connection("%s://%s" % (target_proto, target))
    args = dict(rserver=[remote])

    # load the cert/key created by GenericProxy into each of pproxy's server SSL contexts
    _, cert_file_name, key_file_name = GenericProxy.create_ssl_cert()
    for context in pproxy.server.sslcontexts:
        context.load_cert_chain(cert_file_name, key_file_name)

    loop = ensure_event_loop()
    handler = loop.run_until_complete(server.start_server(args))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("exit!")
    finally:
        # release the listening socket and the loop regardless of how serving ended;
        # the nested finally guarantees loop.close() even if the shutdown steps fail
        try:
            handler.close()
            loop.run_until_complete(handler.wait_closed())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
예제 #2
0
def _do_start_ssl_proxy(
    port: int,
    target: PortOrUrl,
    target_ssl=False,
    client_cert_key: Tuple[str, str] = None,
    bind_address: str = "0.0.0.0",
):
    """
    Starts a tcp proxy (with tls) on the specified port

    Blocks the calling thread: the event loop runs until it is stopped or a
    KeyboardInterrupt is received. Afterwards the server and the event loop are
    always torn down, even if serving failed with another exception.

    :param port: Port the proxy should bind to
    :param target: Target of the proxy. If it contains no ":", it is treated as a port and
        the proxy connects to 127.0.0.1:<target>; otherwise it is used as host:port
    :param target_ssl: Specify if the proxy should connect to the target using SSL/TLS
    :param client_cert_key: Client certificate for the target connection. Only set if target_ssl=True.
        It is persisted via ``_save_cert_keys`` and the resulting cert/key files are loaded into
        the remote connection's client SSL context (presumably a (cert, key) pair - confirm)
    :param bind_address: Bind address of the proxy server
    """
    import pproxy

    from localstack.services.generic_proxy import GenericProxy

    if ":" not in str(target):
        target = f"127.0.0.1:{target}"
    LOG.debug("Starting SSL proxy server %s -> %s", port, target)

    # create server and remote connection
    server = pproxy.Server(f"secure+tunnel://{bind_address}:{port}")
    target_proto = "ssl+tunnel" if target_ssl else "tunnel"
    remote = pproxy.Connection(f"{target_proto}://{target}")
    if client_cert_key:
        # TODO verify client certs server side?
        LOG.debug("Configuring ssl proxy to use client certs")
        cert_file, key_file = _save_cert_keys(client_cert_key=client_cert_key)
        remote.sslclient.load_cert_chain(certfile=cert_file, keyfile=key_file)
    args = dict(rserver=[remote])

    # load the cert/key created by GenericProxy into each of pproxy's server SSL contexts
    _, cert_file_name, key_file_name = GenericProxy.create_ssl_cert()
    for context in pproxy.server.sslcontexts:
        context.load_cert_chain(cert_file_name, key_file_name)

    loop = ensure_event_loop()
    handler = loop.run_until_complete(server.start_server(args))
    try:
        loop.run_forever()
    except KeyboardInterrupt:
        print("exit!")
    finally:
        # release the listening socket and the loop regardless of how serving ended;
        # the nested finally guarantees loop.close() even if the shutdown steps fail
        try:
            handler.close()
            loop.run_until_complete(handler.wait_closed())
            loop.run_until_complete(loop.shutdown_asyncgens())
        finally:
            loop.close()
예제 #3
0
 def run_app_sync(*args, loop=None, shutdown_event=None):
     """
     Serve ``app`` on ``port`` for every address in ``bind_addresses`` and block until it exits.

     ``app``, ``port``, ``bind_addresses`` and ``ssl_creds`` are free variables taken from the
     enclosing scope (not visible in this snippet).

     :param args: ignored
     :param loop: event loop to run on; defaults to ``ensure_event_loop()``
     :param shutdown_event: optional event whose ``wait`` is passed to ``serve`` as
         ``shutdown_trigger``
     :return: the result of ``serve(app, config, ...)``
     :raises Exception: any error raised while serving is logged and re-raised

     The event loop is always closed on exit, and the current event loop is reset to None.
     """
     config = Config()
     cert_file_name, key_file_name = ssl_creds or (None, None)
     if cert_file_name:
         config.certfile = cert_file_name
     if key_file_name:
         config.keyfile = key_file_name
     setup_quart_logging()
     config.bind = [f"{bind_address}:{port}" for bind_address in bind_addresses]
     config.workers = len(bind_addresses)
     loop = loop or ensure_event_loop()
     run_kwargs = {}
     if shutdown_event:
         run_kwargs["shutdown_trigger"] = shutdown_event.wait
     try:
         try:
             return loop.run_until_complete(serve(app, config, **run_kwargs))
         except Exception as e:
             LOG.info(
                 "Error running server event loop on port %s: %s %s",
                 port,
                 e,
                 traceback.format_exc(),
             )
             if "SSL" in str(e):
                 # cert/key names are None when no ssl_creds were configured; os.path.exists(None)
                 # raises TypeError, which would mask the original exception re-raised below
                 c_exists = bool(cert_file_name) and os.path.exists(cert_file_name)
                 k_exists = bool(key_file_name) and os.path.exists(key_file_name)
                 c_size = len(load_file(cert_file_name)) if c_exists else 0
                 k_size = len(load_file(key_file_name)) if k_exists else 0
                 LOG.warning(
                     "Unable to create SSL context. Cert files exist: %s %s (%sB), %s %s (%sB)",
                     cert_file_name,
                     c_exists,
                     c_size,
                     key_file_name,
                     k_exists,
                     k_size,
                 )
             raise
     finally:
         try:
             _cancel_all_tasks(loop)
             loop.run_until_complete(loop.shutdown_asyncgens())
         finally:
             asyncio.set_event_loop(None)
             loop.close()
예제 #4
0
 def run_proxy(self, *args):
     """
     Thread body: set up an event loop and a shutdown event on this object, then serve.

     The loop is obtained before the event is created so the event is made while that
     loop is the current one. Both are stored on ``self`` before serving starts.
     Extra positional ``args`` are ignored.
     """
     event_loop = ensure_event_loop()
     self.loop = event_loop
     stop_signal = asyncio.Event()
     self.shutdown_event = stop_signal
     run_app_sync(loop=event_loop, shutdown_event=stop_signal)
예제 #5
0
def run_server(
    port: int,
    bind_address: str,
    handler: Callable = None,
    asynchronous: bool = True,
    ssl_creds: Tuple[str, str] = None,
    max_content_length: int = None,
    send_timeout: int = None,
):
    """
    Run an HTTP2-capable Web server on the given port, processing incoming requests via a `handler` function.
    :param port: port to bind to
    :param bind_address: address to bind to
    :param handler: callable invoked as ``handler(request, data)`` (via ``run_sync``) for every request;
        it may return a response-like object, return or raise an exception (answered with a 500, or
        with ``e.code`` for ``HTTPErrorResponse``), or return None (an empty ``{}`` body is sent)
    :param asynchronous: whether to start the server asynchronously in the background
    :param ssl_creds: optional tuple with SSL cert file names (cert file, key file)
    :param max_content_length: maximum content length of uploaded payload
    :param send_timeout: timeout (in seconds) for sending the request payload over the wire
    :return: the started ``ProxyThread`` if ``asynchronous`` is True; otherwise the result of serving
        the app in the current thread, which blocks until the server exits
    """

    ensure_event_loop()
    app = Quart(__name__, static_folder=None)
    app.config["MAX_CONTENT_LENGTH"] = max_content_length or DEFAULT_MAX_CONTENT_LENGTH
    if send_timeout:
        app.config["BODY_TIMEOUT"] = send_timeout

    @app.route("/", methods=HTTP_METHODS, defaults={"path": ""})
    @app.route("/<path:path>", methods=HTTP_METHODS)
    async def index(path=None):
        """Catch-all route: pass each request to ``handler`` and convert its result to a response."""
        response = await make_response("{}")
        if handler:
            data = await request.get_data()
            try:
                result = await run_sync(handler, request, data)
                # the handler may signal failure by returning (instead of raising) an exception
                if isinstance(result, Exception):
                    raise result
            except Exception as e:
                LOG.warning(
                    "Error in proxy handler for request %s %s: %s %s",
                    request.method,
                    request.url,
                    e,
                    traceback.format_exc(),
                )
                response.status_code = 500
                if isinstance(e, HTTPErrorResponse):
                    response.status_code = e.code or response.status_code
                return response
            if result is not None:
                # check if this is an async generator (for HTTP2 push event responses)
                async_gen = get_async_generator_result(result)
                if async_gen:
                    return async_gen
                # prepare and return regular response
                is_chunked = uses_chunked_encoding(result)
                result_content = result.content or ""
                response = await make_response(result_content)
                response.status_code = result.status_code
                if is_chunked:
                    response.headers.pop("Content-Length", None)
                result.headers.pop("Server", None)
                result.headers.pop("Date", None)
                # escape newlines so multi-line header values cannot break the header block
                headers = {k: str(v).replace("\n", r"\n") for k, v in result.headers.items()}
                response.headers.update(headers)
                # set multi-value headers
                multi_value_headers = getattr(result, "multi_value_headers", {})
                for key, values in multi_value_headers.items():
                    for value in values:
                        response.headers.add_header(key, value)
                # set default headers, if required
                if not is_chunked and request.method not in ["OPTIONS", "HEAD"]:
                    response_data = await response.get_data()
                    response.headers["Content-Length"] = str(len(response_data or ""))
                if "Connection" not in response.headers:
                    response.headers["Connection"] = "close"
                # fix headers for OPTIONS requests (possible fix for Firefox requests)
                if request.method == "OPTIONS":
                    response.headers.pop("Content-Type", None)
                    if not response.headers.get("Cache-Control"):
                        response.headers["Cache-Control"] = "no-cache"
        return response

    def run_app_sync(*args, loop=None, shutdown_event=None):
        """
        Serve ``app`` on ``bind_address:port`` and block until it exits.

        :param args: ignored
        :param loop: event loop to run on; defaults to ``ensure_event_loop()``
        :param shutdown_event: optional event whose ``wait`` is passed to ``serve`` as
            ``shutdown_trigger``
        :return: the result of ``serve(app, config, ...)``
        :raises Exception: any error raised while serving is logged and re-raised

        The event loop is always closed on exit, and the current event loop is reset to None.
        """
        config = Config()
        cert_file_name, key_file_name = ssl_creds or (None, None)
        if cert_file_name:
            config.certfile = cert_file_name
        if key_file_name:
            config.keyfile = key_file_name
        setup_quart_logging()
        config.bind = [f"{bind_address}:{port}"]
        loop = loop or ensure_event_loop()
        run_kwargs = {}
        if shutdown_event:
            run_kwargs["shutdown_trigger"] = shutdown_event.wait
        try:
            try:
                return loop.run_until_complete(serve(app, config, **run_kwargs))
            except Exception as e:
                LOG.info(
                    "Error running server event loop on port %s: %s %s",
                    port,
                    e,
                    traceback.format_exc(),
                )
                if "SSL" in str(e):
                    # cert/key names are None when no ssl_creds were given; os.path.exists(None)
                    # raises TypeError, which would mask the original exception re-raised below
                    c_exists = bool(cert_file_name) and os.path.exists(cert_file_name)
                    k_exists = bool(key_file_name) and os.path.exists(key_file_name)
                    c_size = len(load_file(cert_file_name)) if c_exists else 0
                    k_size = len(load_file(key_file_name)) if k_exists else 0
                    LOG.warning(
                        "Unable to create SSL context. Cert files exist: %s %s (%sB), %s %s (%sB)",
                        cert_file_name,
                        c_exists,
                        c_size,
                        key_file_name,
                        k_exists,
                        k_size,
                    )
                raise
        finally:
            try:
                _cancel_all_tasks(loop)
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                asyncio.set_event_loop(None)
                loop.close()

    class ProxyThread(FuncThread):
        """Background thread that runs the server on its own event loop until stopped."""

        def __init__(self):
            FuncThread.__init__(self, self.run_proxy, None)
            # both are populated by run_proxy once the thread is running
            self.shutdown_event = None
            self.loop = None

        def run_proxy(self, *args):
            # obtain the loop first so the Event is created while that loop is current
            self.loop = ensure_event_loop()
            self.shutdown_event = asyncio.Event()
            run_app_sync(loop=self.loop, shutdown_event=self.shutdown_event)

        def stop(self, quiet=None):
            # NOTE(review): if stop() is called before run_proxy has executed, shutdown_event
            # and loop are still None here - confirm callers only stop a running thread
            event = self.shutdown_event

            async def set_event():
                event.set()

            # set the event from inside the server's loop, since it is owned by another thread
            run_coroutine(set_event(), self.loop)
            super().stop(quiet)

    def run_in_thread():
        """Start the server in a ``ProxyThread``, register it in ``TMP_THREADS`` and return it."""
        thread = ProxyThread()
        thread.start()
        TMP_THREADS.append(thread)
        return thread

    if asynchronous:
        return run_in_thread()

    return run_app_sync()