Exemplo n.º 1
0
    async def _connect(self, urlparsed: ParseResult, verify: bool,
                       ssl_context: SSLContext, http2: bool):
        """Get reader and writer.

        The current connection is reused when it was opened for the same
        ``"<hostname>-<port>"`` key and its writer is not closing. Otherwise
        the old writer (if any) is closed, a new connection is opened with
        ``asyncio.open_connection`` and stored in ``self.reader`` /
        ``self.writer``, and ``self._connection_made()`` is awaited.

        :param urlparsed: parsed target URL; ``hostname``, ``port`` and
            ``scheme`` are read.
        :param verify: for https, ``False`` disables hostname checking and
            certificate verification.
        :param ssl_context: TLS context; for https a default client context
            (``ssl.create_default_context(ssl.Purpose.SERVER_AUTH)``) is used
            when this is falsy. A caller-supplied context is mutated in place.
        :param http2: for https, advertise ``h2`` and ``http/1.1`` via ALPN.
        """
        key = f'{urlparsed.hostname}-{urlparsed.port}'

        if self.writer:
            # Python 3.6's StreamWriter has no ``is_closing``; fall back to the
            # transport's. The fallback is resolved lazily because a
            # ``getattr`` default is evaluated eagerly and would touch the
            # private ``_transport`` even when the public method exists.
            is_closing = getattr(self.writer, 'is_closing', None)
            if is_closing is None:
                is_closing = self.writer._transport.is_closing  # type: ignore
        else:

            def is_closing():
                # No writer: there is no connection to reuse.
                return True  # noqa

        # Reuse only a still-open connection to the same host/port.
        if not (self.key and key == self.key and not is_closing()):
            if self.writer:
                self.writer.close()

            if urlparsed.scheme == 'https':
                ssl_context = ssl_context or ssl.create_default_context(
                    ssl.Purpose.SERVER_AUTH, )
                # NOTE: the settings below mutate a caller-supplied context.
                if http2:  # flag will be removed when fully http2 support
                    ssl_context.set_alpn_protocols(['h2', 'http/1.1'])
                if not verify:
                    # check_hostname must be cleared before CERT_NONE is set.
                    ssl_context.check_hostname = False
                    ssl_context.verify_mode = ssl.CERT_NONE
            port = urlparsed.port or (443
                                      if urlparsed.scheme == 'https' else 80)
            self.reader, self.writer = await asyncio.open_connection(
                urlparsed.hostname, port, ssl=ssl_context)

            # NOTE(review): stored as temp_key while the reuse check reads
            # self.key; presumably _connection_made promotes it -- confirm.
            self.temp_key = key
            await self._connection_made()
Exemplo n.º 2
0
    async def test_alpn_negotiation(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """ALPN negotiation selects the one protocol both peers offer.

        A blocking TLS listener in a worker thread echoes back the protocol it
        negotiated. The client checks its own view
        (``TLSAttribute.alpn_protocol``) and the server's echo.
        """
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                conn.settimeout(1)
                selected_alpn_protocol = conn.selected_alpn_protocol()
                assert selected_alpn_protocol is not None
                conn.send(selected_alpn_protocol.encode())

        # "dummy2" is the only protocol present in both lists.
        server_context.set_alpn_protocols(["dummy1", "dummy2"])
        client_context.set_alpn_protocols(["dummy2", "dummy3"])

        server_sock = server_context.wrap_socket(socket.socket(),
                                                 server_side=True,
                                                 suppress_ragged_eofs=False)
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync)
        server_thread.start()

        # Always reap the server thread and the listening socket, even when a
        # client-side step fails.
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(stream,
                                               hostname="localhost",
                                               ssl_context=client_context)
                assert wrapper.extra(TLSAttribute.alpn_protocol) == "dummy2"
                server_alpn_protocol = await wrapper.receive()
        finally:
            server_thread.join()
            server_sock.close()

        assert server_alpn_protocol == b"dummy2"
Exemplo n.º 3
0
 def _create_ssl_context(self) -> Optional[SSLContext]:
     """Build the server SSL context described by ``self.cfg``.

     Returns ``None`` when ``cfg.is_ssl`` is falsy. Otherwise the context is
     created for ``cfg.ssl_version`` and loaded with ``cfg.certfile`` /
     ``cfg.keyfile``. CA locations (``cfg.ca_certs``) and a cipher string
     (``cfg.ciphers``) are applied only when set. ``h2`` and ``http/1.1`` are
     offered via ALPN.
     """
     cfg = self.cfg
     if not cfg.is_ssl:
         return None

     context = SSLContext(cfg.ssl_version)
     context.load_cert_chain(cfg.certfile, cfg.keyfile)
     if cfg.ca_certs:
         context.load_verify_locations(cfg.ca_certs)
     if cfg.ciphers:
         context.set_ciphers(cfg.ciphers)
     context.set_alpn_protocols(['h2', 'http/1.1'])
     return context
Exemplo n.º 4
0
    async def test_extra_attributes(self, server_context: ssl.SSLContext,
                                    client_context: ssl.SSLContext) -> None:
        """A TLS wrapper exposes socket attributes and the TLS attributes.

        Every public ``SocketAttribute`` must match the underlying stream. The
        TLS-specific attributes must have the expected types or values for a
        client-side, non-standard-compatible connection that negotiated
        ``h2``.
        """
        def serve_sync() -> None:
            conn, _ = server_sock.accept()
            with conn:
                conn.settimeout(1)
                # Keep the connection open until the client's single byte
                # arrives, i.e. until the client has checked its attributes.
                conn.recv(1)

        server_context.set_alpn_protocols(["h2"])
        client_context.set_alpn_protocols(["h2"])

        server_sock = server_context.wrap_socket(socket.socket(),
                                                 server_side=True,
                                                 suppress_ragged_eofs=True)
        server_sock.settimeout(1)
        server_sock.bind(("127.0.0.1", 0))
        server_sock.listen()
        server_thread = Thread(target=serve_sync)
        server_thread.start()

        # Always reap the server thread and the listening socket, even when an
        # assertion below fails.
        try:
            async with await connect_tcp(*server_sock.getsockname()) as stream:
                wrapper = await TLSStream.wrap(
                    stream,
                    hostname="localhost",
                    ssl_context=client_context,
                    standard_compatible=False,
                )
                async with wrapper:
                    for name, attribute in SocketAttribute.__dict__.items():
                        if not name.startswith("_"):
                            assert wrapper.extra(attribute) == stream.extra(
                                attribute)

                    assert wrapper.extra(TLSAttribute.alpn_protocol) == "h2"
                    assert isinstance(
                        wrapper.extra(TLSAttribute.channel_binding_tls_unique),
                        bytes)
                    assert isinstance(wrapper.extra(TLSAttribute.cipher), tuple)
                    assert isinstance(
                        wrapper.extra(TLSAttribute.peer_certificate), dict)
                    assert isinstance(
                        wrapper.extra(TLSAttribute.peer_certificate_binary),
                        bytes)
                    assert wrapper.extra(TLSAttribute.server_side) is False
                    assert isinstance(
                        wrapper.extra(TLSAttribute.shared_ciphers), list)
                    assert isinstance(wrapper.extra(TLSAttribute.ssl_object),
                                      ssl.SSLObject)
                    assert wrapper.extra(
                        TLSAttribute.standard_compatible) is False
                    assert wrapper.extra(
                        TLSAttribute.tls_version).startswith("TLSv")
                    await wrapper.send(b"\x00")
        finally:
            server_thread.join()
            server_sock.close()
Exemplo n.º 5
0
    async def _connect(self, urlparsed: ParseResult, verify: bool,
                       ssl_context: SSLContext, dns_info, http2: bool) -> None:
        """Get reader and writer.

        The current connection is reused when it was opened for the same
        ``"<hostname>-<port>"`` key and its writer is not closing. Otherwise
        ``self.close()`` is called and a new connection is opened with
        ``open_connection``, using the keyword arguments from a copy of
        *dns_info*. The result is stored in ``self.reader`` / ``self.writer``
        and ``self._connection_made()`` is awaited.

        :param urlparsed: parsed target URL; ``hostname``, ``port`` and
            ``scheme`` are read.
        :param verify: for https, ``False`` disables hostname checking and
            certificate verification.
        :param ssl_context: TLS context; for https a default client context
            is used when this is falsy. A caller-supplied context is mutated
            in place.
        :param dns_info: mapping with a ``'hostname'`` entry. It is copied
            (the caller's object is not modified), ``'hostname'`` is renamed
            to ``'server_hostname'`` (dropped for non-https), and ``'port'``
            is added.
        :param http2: for https, advertise ``h2`` and ``http/1.1`` via ALPN.
        :raises HttpParsingError: if the URL has no hostname.
        """
        if not urlparsed.hostname:
            raise HttpParsingError('missing hostname')

        key = f'{urlparsed.hostname}-{urlparsed.port}'

        if self.writer:
            # Python 3.6's StreamWriter has no ``is_closing``; fall back to the
            # transport's. The fallback is resolved lazily because a
            # ``getattr`` default is evaluated eagerly and would touch the
            # private ``_transport`` even when the public method exists.
            is_closing = getattr(self.writer, 'is_closing', None)
            if is_closing is None:
                is_closing = self.writer._transport.is_closing  # type: ignore
        else:

            def is_closing():
                # No writer: there is no connection to reuse.
                return True  # noqa

        dns_info_copy = dns_info.copy()
        dns_info_copy['server_hostname'] = dns_info_copy.pop('hostname')

        # Reuse only a still-open connection to the same host/port.
        if not (self.key and key == self.key and not is_closing()):
            self.close()

            if urlparsed.scheme == 'https':
                ssl_context = ssl_context or ssl.create_default_context(
                    ssl.Purpose.SERVER_AUTH, )
                # NOTE: the settings below mutate a caller-supplied context.
                # flag will be removed when fully http2 support
                if http2:  # pragma: no cover
                    ssl_context.set_alpn_protocols(['h2', 'http/1.1'])
                if not verify:
                    # check_hostname must be cleared before CERT_NONE is set.
                    ssl_context.check_hostname = False
                    ssl_context.verify_mode = ssl.CERT_NONE
            else:
                # SNI is only meaningful for TLS; presumably open_connection
                # rejects server_hostname without ssl (asyncio does) -- confirm.
                del dns_info_copy['server_hostname']
            port = urlparsed.port or (443
                                      if urlparsed.scheme == 'https' else 80)
            dns_info_copy['port'] = port
            self.reader, self.writer = await open_connection(
                **dns_info_copy, ssl=ssl_context)

            # NOTE(review): stored as temp_key while the reuse check reads
            # self.key; presumably _connection_made promotes it -- confirm.
            self.temp_key = key
            await self._connection_made()
Exemplo n.º 6
0
class SyncHTTPConnection(SyncHTTPTransport):
    """One HTTP/1.1 or HTTP/2 connection to a single origin (sync flavour).

    Unless a socket is passed in, it is opened lazily by the first
    ``request``. The protocol implementation is chosen from the HTTP version
    the socket reports.
    """

    def __init__(
        self,
        origin: Origin,
        http2: bool = False,
        ssl_context: SSLContext = None,
        socket: SyncSocketStream = None,
    ):
        self.origin = origin
        self.http2 = http2
        # NOTE(review): a bare SSLContext() does not verify certificates by
        # default; a caller-supplied context is used (and mutated) as-is.
        if ssl_context is None:
            ssl_context = SSLContext()
        self.ssl_context = ssl_context
        self.socket = socket

        if http2:
            # Offer both protocols over ALPN so the server may pick HTTP/2.
            ssl_context.set_alpn_protocols(["http/1.1", "h2"])

        self.connection: Union[None, SyncHTTP11Connection, SyncHTTP2Connection] = None
        self.is_http11 = False
        self.is_http2 = False
        self.connect_failed = False
        self.expires_at: Optional[float] = None
        self.backend = SyncBackend()

    @property
    def request_lock(self) -> SyncLock:
        """Lock guarding connection setup, created on first access.

        Creation is deferred (mirroring the async variant, where backend
        autodetection must run inside an async context).
        """
        try:
            return self._request_lock
        except AttributeError:
            self._request_lock = self.backend.create_lock()
            return self._request_lock

    def request(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: SyncByteStream = None,
        timeout: TimeoutDict = None,
    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], SyncByteStream]:
        """Send a request, connecting first if nothing is open yet.

        Raises ``NewConnectionRequired`` when the existing connection cannot
        accept another request (anything other than ready/idle, or active
        HTTP/2).
        """
        assert url_to_origin(url) == self.origin
        with self.request_lock:
            current = self.state
            if current == ConnectionState.PENDING:
                if not self.socket:
                    logger.trace(
                        "open_socket origin=%r timeout=%r", self.origin, timeout
                    )
                    self.socket = self._open_socket(timeout)
                self._create_connection(self.socket)
            else:
                reusable = current in (
                    ConnectionState.READY,
                    ConnectionState.IDLE,
                ) or (current == ConnectionState.ACTIVE and self.is_http2)
                if not reusable:
                    raise NewConnectionRequired()

        assert self.connection is not None
        logger.trace(
            "connection.request method=%r url=%r headers=%r", method, url, headers
        )
        return self.connection.request(method, url, headers, stream, timeout)

    def _open_socket(self, timeout: TimeoutDict = None) -> SyncSocketStream:
        """Open a TCP stream to the origin (TLS only for https).

        Any failure marks the connection as failed before re-raising.
        """
        scheme, hostname, port = self.origin
        if timeout is None:
            timeout = {}
        tls_context = None
        if scheme == b"https":
            tls_context = self.ssl_context
        try:
            return self.backend.open_tcp_stream(hostname, port, tls_context, timeout)
        except Exception:
            self.connect_failed = True
            raise

    def _create_connection(self, socket: SyncSocketStream) -> None:
        """Wrap *socket* in the protocol implementation it negotiated."""
        negotiated = socket.get_http_version()
        logger.trace(
            "create_connection socket=%r http_version=%r", socket, negotiated
        )
        if negotiated != "HTTP/2":
            self.is_http11 = True
            self.connection = SyncHTTP11Connection(
                socket=socket, ssl_context=self.ssl_context
            )
            return
        self.is_http2 = True
        self.connection = SyncHTTP2Connection(
            socket=socket, backend=self.backend, ssl_context=self.ssl_context
        )

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before connecting, else
        the underlying connection's state."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        if self.connection is None:
            return ConnectionState.PENDING
        return self.connection.state

    def is_connection_dropped(self) -> bool:
        """Whether an opened connection reports it was dropped."""
        if self.connection is None:
            return False
        return self.connection.is_connection_dropped()

    def mark_as_ready(self) -> None:
        """Forward ``mark_as_ready`` to the underlying connection, if any."""
        if self.connection is None:
            return
        self.connection.mark_as_ready()

    def start_tls(self, hostname: bytes, timeout: TimeoutDict = None) -> None:
        """Upgrade the open connection to TLS and adopt its new socket."""
        if self.connection is None:
            return
        logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
        self.connection.start_tls(hostname, timeout)
        logger.trace("start_tls complete hostname=%r timeout=%r", hostname, timeout)
        self.socket = self.connection.socket

    def close(self) -> None:
        """Close the underlying connection (if any) under the request lock."""
        with self.request_lock:
            if self.connection is None:
                return
            self.connection.close()
Exemplo n.º 7
0
class AsyncHTTPConnection(AsyncHTTPTransport):
    """One HTTP/1.1 or HTTP/2 connection to a single origin.

    Unless a socket is passed in, it is opened lazily by the first
    ``arequest``: over TCP, or over the Unix domain socket ``uds`` when one is
    given. The protocol is chosen from the HTTP version the socket reports.
    ``ConnectError`` / ``ConnectTimeout`` are retried up to ``retries`` times
    with exponential backoff.
    """

    def __init__(
        self,
        origin: Origin,
        http2: bool = False,
        uds: str = None,
        ssl_context: SSLContext = None,
        socket: AsyncSocketStream = None,
        local_address: str = None,
        retries: int = 0,
        backend: AsyncBackend = None,
    ):
        self.origin = origin
        self.http2 = http2
        self.uds = uds
        # NOTE(review): a bare SSLContext() does not verify certificates by
        # default; a caller-supplied context is used (and mutated) as-is.
        if ssl_context is None:
            ssl_context = SSLContext()
        self.ssl_context = ssl_context
        self.socket = socket
        self.local_address = local_address
        self.retries = retries

        if http2:
            # Offer both protocols over ALPN so the server may pick HTTP/2.
            ssl_context.set_alpn_protocols(["http/1.1", "h2"])

        self.connection: Optional[AsyncBaseHTTPConnection] = None
        self.is_http11 = False
        self.is_http2 = False
        self.connect_failed = False
        self.expires_at: Optional[float] = None
        self.backend = backend if backend is not None else AutoBackend()

    def __repr__(self) -> str:
        if self.is_http11:
            http_version = "HTTP/1.1"
        elif self.is_http2:
            http_version = "HTTP/2"
        else:
            http_version = "UNKNOWN"
        return f"<AsyncHTTPConnection http_version={http_version} state={self.state}>"

    def info(self) -> str:
        """Short human-readable status of this connection."""
        if self.connection is not None:
            if self.state == ConnectionState.PENDING:
                return "Connecting"
            return self.connection.info()
        return "Not connected"

    @property
    def request_lock(self) -> AsyncLock:
        """Lock guarding connection setup, created on first access.

        Creation is deferred so backend autodetection runs inside an async
        context.
        """
        try:
            return self._request_lock
        except AttributeError:
            self._request_lock = self.backend.create_lock()
            return self._request_lock

    async def arequest(
        self,
        method: bytes,
        url: URL,
        headers: Headers = None,
        stream: AsyncByteStream = None,
        ext: dict = None,
    ) -> Tuple[int, Headers, AsyncByteStream, dict]:
        """Send a request, connecting first if nothing is open yet.

        ``ext["timeout"]`` (if present) is used for opening the socket.
        Raises ``NewConnectionRequired`` when the existing connection cannot
        accept another request.
        """
        assert url_to_origin(url) == self.origin
        if ext is None:
            ext = {}
        timeout = cast(TimeoutDict, ext.get("timeout", {}))

        async with self.request_lock:
            current = self.state
            if current == ConnectionState.PENDING:
                if not self.socket:
                    logger.trace("open_socket origin=%r timeout=%r",
                                 self.origin, timeout)
                    self.socket = await self._open_socket(timeout)
                self._create_connection(self.socket)
            else:
                reusable = current in (
                    ConnectionState.READY,
                    ConnectionState.IDLE,
                ) or (current == ConnectionState.ACTIVE and self.is_http2)
                if not reusable:
                    raise NewConnectionRequired()

        assert self.connection is not None
        logger.trace("connection.arequest method=%r url=%r headers=%r", method,
                     url, headers)
        return await self.connection.arequest(method, url, headers, stream,
                                              ext)

    async def _open_socket(self,
                           timeout: TimeoutDict = None) -> AsyncSocketStream:
        """Open the transport stream, retrying connect errors with backoff.

        TLS is used only for https origins. When retries are exhausted, or on
        any other exception, the connection is marked failed and the error
        propagates.
        """
        scheme, hostname, port = self.origin
        if timeout is None:
            timeout = {}
        tls_context = self.ssl_context if scheme == b"https" else None

        attempts_left = self.retries
        backoff = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self.uds is not None:
                    return await self.backend.open_uds_stream(
                        self.uds, hostname, tls_context, timeout)
                return await self.backend.open_tcp_stream(
                    hostname,
                    port,
                    tls_context,
                    timeout,
                    local_address=self.local_address,
                )
            except (ConnectError, ConnectTimeout):
                if attempts_left <= 0:
                    self.connect_failed = True
                    raise
                attempts_left -= 1
                await self.backend.sleep(next(backoff))
            except Exception:  # noqa: PIE786
                self.connect_failed = True
                raise

    def _create_connection(self, socket: AsyncSocketStream) -> None:
        """Wrap *socket* in the protocol implementation it negotiated."""
        negotiated = socket.get_http_version()
        logger.trace("create_connection socket=%r http_version=%r", socket,
                     negotiated)
        if negotiated != "HTTP/2":
            from .http11 import AsyncHTTP11Connection

            self.is_http11 = True
            self.connection = AsyncHTTP11Connection(
                socket=socket, ssl_context=self.ssl_context)
            return

        from .http2 import AsyncHTTP2Connection

        self.is_http2 = True
        self.connection = AsyncHTTP2Connection(
            socket=socket,
            backend=self.backend,
            ssl_context=self.ssl_context)

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before connecting, else
        the underlying connection's state."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        if self.connection is None:
            return ConnectionState.PENDING
        return self.connection.get_state()

    def is_connection_dropped(self) -> bool:
        """Whether an opened connection reports it was dropped."""
        if self.connection is None:
            return False
        return self.connection.is_connection_dropped()

    def mark_as_ready(self) -> None:
        """Forward ``mark_as_ready`` to the underlying connection, if any."""
        if self.connection is None:
            return
        self.connection.mark_as_ready()

    async def start_tls(self,
                        hostname: bytes,
                        timeout: TimeoutDict = None) -> None:
        """Upgrade the open connection to TLS and adopt the returned socket."""
        if self.connection is None:
            return
        logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
        self.socket = await self.connection.start_tls(hostname, timeout)
        logger.trace("start_tls complete hostname=%r timeout=%r", hostname,
                     timeout)

    async def aclose(self) -> None:
        """Close the underlying connection (if any) under the request lock."""
        async with self.request_lock:
            if self.connection is None:
                return
            await self.connection.aclose()
Exemplo n.º 8
0
class AsyncHTTPConnection(AsyncHTTPTransport):
    """One HTTP/1.1 or HTTP/2 connection to a single origin.

    The TCP stream is opened lazily by the first ``request``. The protocol
    implementation is chosen from the HTTP version the stream reports.
    """

    def __init__(
        self,
        origin: Tuple[bytes, bytes, int],
        http2: bool = False,
        ssl_context: Optional[SSLContext] = None,
    ):
        """
        :param origin: ``(scheme, host, port)`` this connection serves.
        :param http2: offer HTTP/2 (alongside HTTP/1.1) via ALPN.
        :param ssl_context: TLS context for https origins. When omitted, a
            bare ``SSLContext()`` is used. A supplied context is mutated
            (ALPN) when *http2* is set.
        """
        self.origin = origin
        self.http2 = http2
        self.ssl_context = SSLContext() if ssl_context is None else ssl_context

        if self.http2:
            self.ssl_context.set_alpn_protocols(["http/1.1", "h2"])

        self.connection: Union[None, AsyncHTTP11Connection,
                               AsyncHTTP2Connection] = None
        self.is_http11 = False
        self.is_http2 = False
        self.connect_failed = False
        self.expires_at: Optional[float] = None
        self.backend = AutoBackend()

    @property
    def request_lock(self) -> AsyncLock:
        # We do this lazily, to make sure backend autodetection always
        # runs within an async context.
        if not hasattr(self, "_request_lock"):
            self._request_lock = self.backend.create_lock()
        return self._request_lock

    async def request(
        self,
        method: bytes,
        url: Tuple[bytes, bytes, int, bytes],
        headers: Optional[List[Tuple[bytes, bytes]]] = None,
        stream: Optional[AsyncByteStream] = None,
        timeout: Optional[Dict[str, Optional[float]]] = None,
    ) -> Tuple[bytes, int, bytes, List[Tuple[bytes, bytes]], AsyncByteStream]:
        """Send a request, connecting first if nothing is open yet.

        :raises NewConnectionRequired: if the existing connection cannot
            accept another request (anything other than ready/idle, or active
            HTTP/2).
        """
        assert url[:3] == self.origin

        async with self.request_lock:
            if self.state == ConnectionState.PENDING:
                try:
                    await self._connect(timeout)
                except BaseException:
                    # Any failure, including cancellation, leaves this
                    # connection unusable; record that, then propagate.
                    self.connect_failed = True
                    raise
            elif self.state in (ConnectionState.READY, ConnectionState.IDLE):
                pass
            elif self.state == ConnectionState.ACTIVE and self.is_http2:
                pass
            else:
                raise NewConnectionRequired()

        assert self.connection is not None
        return await self.connection.request(method, url, headers, stream,
                                             timeout)

    async def _connect(
        self,
        timeout: Optional[Dict[str, Optional[float]]] = None,
    ) -> None:
        """Open the TCP stream (TLS for https) and wrap it in the protocol
        implementation matching its reported HTTP version."""
        scheme, hostname, port = self.origin
        timeout = {} if timeout is None else timeout
        ssl_context = self.ssl_context if scheme == b"https" else None
        socket = await self.backend.open_tcp_stream(hostname, port,
                                                    ssl_context, timeout)
        http_version = socket.get_http_version()
        if http_version == "HTTP/2":
            self.is_http2 = True
            self.connection = AsyncHTTP2Connection(socket=socket,
                                                   backend=self.backend)
        else:
            self.is_http11 = True
            self.connection = AsyncHTTP11Connection(socket=socket)

    @property
    def state(self) -> ConnectionState:
        """CLOSED after a failed connect, PENDING before connecting, else
        the underlying connection's state."""
        if self.connect_failed:
            return ConnectionState.CLOSED
        elif self.connection is None:
            return ConnectionState.PENDING
        return self.connection.state

    def is_connection_dropped(self) -> bool:
        """Whether an opened connection reports it was dropped."""
        return (self.connection is not None
                and self.connection.is_connection_dropped())

    def mark_as_ready(self) -> None:
        """Forward ``mark_as_ready`` to the underlying connection, if any."""
        if self.connection is not None:
            self.connection.mark_as_ready()

    async def start_tls(self,
                        hostname: bytes,
                        timeout: Optional[Dict[str, Optional[float]]] = None
                        ) -> None:
        """Upgrade the open connection to TLS, if a connection exists."""
        if self.connection is not None:
            await self.connection.start_tls(hostname, timeout)
Exemplo n.º 9
0
class SyncHTTPConnection(SyncHTTPTransport):
    """One HTTP/1.1 or HTTP/2 connection to a single origin.

    Unless a socket is passed in, it is opened lazily by the first
    ``handle_request``: over TCP, or over the Unix domain socket ``uds`` when
    one is given. ``ConnectError`` / ``ConnectTimeout`` are retried up to
    ``retries`` times with exponential backoff.
    """

    def __init__(
        self,
        origin: Origin,
        http1: bool = True,
        http2: bool = False,
        keepalive_expiry: Optional[float] = None,
        uds: Optional[str] = None,
        ssl_context: Optional[SSLContext] = None,
        socket: Optional[SyncSocketStream] = None,
        local_address: Optional[str] = None,
        retries: int = 0,
        backend: Optional[SyncBackend] = None,
    ):
        self.origin = origin
        self._http1_enabled = http1
        self._http2_enabled = http2
        self._keepalive_expiry = keepalive_expiry
        self._uds = uds
        # NOTE(review): a bare SSLContext() does not verify certificates by
        # default; a caller-supplied context is used (and its ALPN list
        # overwritten below) as-is.
        self._ssl_context = SSLContext() if ssl_context is None else ssl_context
        self.socket = socket
        self._local_address = local_address
        self._retries = retries

        # Advertise exactly the enabled protocols via ALPN.
        alpn_protocols: List[str] = []
        if http1:
            alpn_protocols.append("http/1.1")
        if http2:
            alpn_protocols.append("h2")

        self._ssl_context.set_alpn_protocols(alpn_protocols)

        self.connection: Optional[SyncBaseHTTPConnection] = None
        self._is_http11 = False
        self._is_http2 = False
        self._connect_failed = False
        self._expires_at: Optional[float] = None
        self._backend = SyncBackend() if backend is None else backend

    def __repr__(self) -> str:
        return f"<SyncHTTPConnection [{self.info()}]>"

    def info(self) -> str:
        """Short human-readable status of this connection."""
        if self.connection is None:
            return "Connection failed" if self._connect_failed else "Connecting"
        return self.connection.info()

    def should_close(self) -> bool:
        """
        Return `True` if the connection is in a state where it should be closed.
        This occurs when any of the following occur:

        * There are no active requests on an HTTP/1.1 connection, and the underlying
          socket is readable. The only valid state the socket can be readable in
          if this occurs is when the b"" EOF marker is about to be returned,
          indicating a server disconnect.
        * There are no active requests being made and the keepalive timeout has passed.
        """
        if self.connection is None:
            return False
        return self.connection.should_close()

    def is_idle(self) -> bool:
        """
        Return `True` if the connection is currently idle.
        """
        if self.connection is None:
            return False
        return self.connection.is_idle()

    def is_closed(self) -> bool:
        """
        Return `True` if the connection is closed, or if opening it failed.
        """
        if self.connection is None:
            return self._connect_failed
        return self.connection.is_closed()

    def is_available(self) -> bool:
        """
        Return `True` if the connection is currently able to accept an outgoing request.
        This occurs when any of the following occur:

        * The connection has not yet been opened, and HTTP/2 support is enabled.
          We don't *know* at this point if we'll end up on an HTTP/2 connection or
          not, but we *might* do, so we indicate availability.
        * The connection has been opened, and is currently idle.
        * The connection is open, and is an HTTP/2 connection. The connection must
          also not currently be exceeding the maximum number of allowable concurrent
          streams and must not have exhausted the maximum total number of stream IDs.
        """
        if self.connection is None:
            # is_closed is a method and must be called; referencing it bare
            # yields a (truthy) bound method, which would make this False.
            return self._http2_enabled and not self.is_closed()
        return self.connection.is_available()

    @property
    def request_lock(self) -> SyncLock:
        # Created lazily, mirroring the async variant where backend
        # autodetection must run within an async context.
        if not hasattr(self, "_request_lock"):
            self._request_lock = self._backend.create_lock()
        return self._request_lock

    def handle_request(
        self,
        method: bytes,
        url: URL,
        headers: Headers,
        stream: SyncByteStream,
        extensions: dict,
    ) -> Tuple[int, Headers, SyncByteStream, dict]:
        """Send a request, opening the connection first if necessary.

        ``extensions["timeout"]`` (if present) is used for opening the socket.
        Raises ``NewConnectionRequired`` if a previous connect failed or the
        open connection cannot take another request.
        """
        assert url_to_origin(url) == self.origin
        timeout = cast(TimeoutDict, extensions.get("timeout", {}))

        with self.request_lock:
            if self.connection is None:
                if self._connect_failed:
                    raise NewConnectionRequired()
                if not self.socket:
                    logger.trace("open_socket origin=%r timeout=%r",
                                 self.origin, timeout)
                    self.socket = self._open_socket(timeout)
                self._create_connection(self.socket)
            elif not self.connection.is_available():
                raise NewConnectionRequired()

        assert self.connection is not None
        logger.trace(
            "connection.handle_request method=%r url=%r headers=%r",
            method,
            url,
            headers,
        )
        return self.connection.handle_request(method, url, headers, stream,
                                              extensions)

    def _open_socket(self, timeout: Optional[TimeoutDict] = None) -> SyncSocketStream:
        """Open the transport stream, retrying connect errors with backoff.

        TLS is used only for https origins. When retries are exhausted, or on
        any other exception, the connection is marked failed and the error
        propagates.
        """
        scheme, hostname, port = self.origin
        timeout = {} if timeout is None else timeout
        ssl_context = self._ssl_context if scheme == b"https" else None

        retries_left = self._retries
        delays = exponential_backoff(factor=RETRIES_BACKOFF_FACTOR)

        while True:
            try:
                if self._uds is None:
                    return self._backend.open_tcp_stream(
                        hostname,
                        port,
                        ssl_context,
                        timeout,
                        local_address=self._local_address,
                    )
                else:
                    return self._backend.open_uds_stream(
                        self._uds, hostname, ssl_context, timeout)
            except (ConnectError, ConnectTimeout):
                if retries_left <= 0:
                    self._connect_failed = True
                    raise
                retries_left -= 1
                delay = next(delays)
                self._backend.sleep(delay)
            except Exception:  # noqa: PIE786
                self._connect_failed = True
                raise

    def _create_connection(self, socket: SyncSocketStream) -> None:
        """Wrap *socket* in an HTTP/2 connection when HTTP/2 was negotiated
        or is the only enabled protocol; otherwise in HTTP/1.1."""
        http_version = socket.get_http_version()
        logger.trace("create_connection socket=%r http_version=%r", socket,
                     http_version)
        if http_version == "HTTP/2" or (self._http2_enabled
                                        and not self._http1_enabled):
            from .http2 import SyncHTTP2Connection

            self._is_http2 = True
            self.connection = SyncHTTP2Connection(
                socket=socket,
                keepalive_expiry=self._keepalive_expiry,
                backend=self._backend,
            )
        else:
            self._is_http11 = True
            self.connection = SyncHTTP11Connection(
                socket=socket, keepalive_expiry=self._keepalive_expiry)

    def start_tls(self,
                  hostname: bytes,
                  ssl_context: SSLContext,
                  timeout: Optional[TimeoutDict] = None) -> None:
        """Upgrade the open connection to TLS and adopt the returned socket."""
        if self.connection is not None:
            logger.trace("start_tls hostname=%r timeout=%r", hostname, timeout)
            self.socket = self.connection.start_tls(hostname, ssl_context,
                                                    timeout)
            logger.trace("start_tls complete hostname=%r timeout=%r", hostname,
                         timeout)

    def close(self) -> None:
        """Close the underlying connection (if any) under the request lock."""
        with self.request_lock:
            if self.connection is not None:
                self.connection.close()