def test_delete_service(
    mocked_director_v0_service_api: MockRouter,
    mocked_director_v2_scheduler: None,
    client: TestClient,
    service: Dict[str, Any],
    exp_status_code: int,
    is_legacy: bool,
    can_save: Optional[bool],
    exp_save_state: bool,
):
    """
    Issue DELETE on a dynamic service and check the returned status code.

    When a 307 is expected, additionally verify that the Location header
    redirects to director-v0's running_interactive_services endpoint for the
    same node, carrying the expected ``can_save`` query parameter.
    """
    node_uuid = service["node_uuid"]

    target = URL(f"/v2/dynamic_services/{node_uuid}")
    # Only attach the query parameter when the test case specifies one.
    if can_save is not None:
        target = target.copy_with(params={"can_save": can_save})

    response = client.delete(str(target), allow_redirects=False)
    assert response.status_code == exp_status_code, (
        f"expected status code {exp_status_code}, received {response.status_code}: {response.text}"
    )

    if exp_status_code != status.HTTP_307_TEMPORARY_REDIRECT:
        return

    # check redirection header goes to director-v0
    assert "location" in response.headers
    location = URL(response.headers["location"])
    assert location.host == "director"
    assert location.path == f"/v0/running_interactive_services/{node_uuid}"
    assert location.params == QueryParams(can_save=exp_save_state)
Beispiel #2
0
def update_query(url: httpx.URL, **new_params: Optional[str]) -> httpx.URL:
    """
    Adds, replaces, or removes query parameters. For each keyword argument, all
    query parameters with that name are removed; then, if the argument is not
    None, it's added as the sole new value of the corresponding query
    parameter.

    >>> str(update_query(httpx.URL("http://example.org/"), yes="1", no=None))
    'http://example.org/?yes=1'
    >>> str(update_query(httpx.URL("http://example.org/?yes=0&no="), yes="1", no=None))
    'http://example.org/?yes=1'

    Additionally, this function puts the URL into a canonical form to increase
    cache hit rates. A generic client can't make any assumptions about how the
    query string is used, but because we only care how WordPress would
    interpret it, we can safely de-duplicate and sort the parameters.

    >>> str(update_query(httpx.URL("http://example.org/?yes=2&yes=1&no=0")))
    'http://example.org/?no=0&yes=1'
    """
    # ``url.query`` is raw bytes here, so keys and values are kept as bytes
    # throughout. Repeated names collapse to their last occurrence, and
    # parse_qsl's default drops parameters with blank values.
    merged = {key: value for key, value in parse_qsl(url.query)}

    for name, value in new_params.items():
        key = name.encode()
        if value is None:
            merged.pop(key, None)
        else:
            merged[key] = value.encode()

    # Sorting yields a canonical ordering for identical parameter sets.
    canonical_query = urlencode(sorted(merged.items()))
    return url.copy_with(query=canonical_query.encode())
Beispiel #3
0
def test_url_copywith_for_userinfo():
    """
    ``copy_with(username=..., password=...)`` percent-encodes reserved
    characters ("@", " ", "%") in the userinfo component, while the
    ``username``/``password`` accessors return the decoded values.
    """
    # These fixture values had been mangled by e-mail obfuscation; they are
    # reconstructed from the encoded form asserted below.
    copy_with_kwargs = {
        "username": "tom@example.org",
        "password": "abc123@ %",
    }
    url = URL("https://example.org")
    new = url.copy_with(**copy_with_kwargs)
    assert str(new) == "https://tom%40example.org:abc123%40%20%25@example.org"
    assert new.username == "tom@example.org"
    assert new.password == "abc123@ %"
Beispiel #4
0
def test_url_copywith_for_authority():
    """
    ``copy_with`` can set every authority component at once (userinfo, host,
    port), and each one is readable back as an attribute and rendered in
    ``str(url)``.
    """
    # The credentials had been replaced by "******" placeholders, which could
    # never match the 5-star expected string; real values are restored.
    copy_with_kwargs = {
        "username": "username",
        "password": "password",
        "port": 444,
        "host": "example.net",
    }
    url = URL("https://example.org")
    new = url.copy_with(**copy_with_kwargs)
    for k, v in copy_with_kwargs.items():
        assert getattr(new, k) == v
    assert str(new) == "https://username:password@example.net:444"
Beispiel #5
0
def test_url():
    """
    Parse a fully-populated URL, check each component accessor and the repr,
    then drop the port and switch the scheme via ``copy_with``.
    """
    url = URL("https://example.org:123/path/to/somewhere?abc=123#anchor")

    # Component accessors, checked in the same order as listed here.
    expected_components = {
        "scheme": "https",
        "host": "example.org",
        "port": 123,
        "authority": "example.org:123",
        "path": "/path/to/somewhere",
        "query": "abc=123",
        "fragment": "anchor",
    }
    for attribute, expected in expected_components.items():
        assert getattr(url, attribute) == expected

    expected_repr = "URL('https://example.org:123/path/to/somewhere?abc=123#anchor')"
    assert repr(url) == expected_repr

    downgraded = url.copy_with(scheme="http", port=None)
    assert downgraded == URL("http://example.org/path/to/somewhere?abc=123#anchor")
    assert downgraded.scheme == "http"
Beispiel #6
0
    def _merge_url(self, url):
        if isinstance(url, str):
            if url.startswith('http://localhost'):
                return url
            url = URL(url)

        new_url = url.copy_with(scheme=self.scheme, host=self.host, port=self.port)
        if 's3.' in url.host:
            return new_url.copy_with(path='/s3/')
        elif 'email.' in url.host:
            return new_url.copy_with(path='/ses/')
        elif url.host.startswith('sns.'):
            if 'bad' in url.path:
                return new_url.copy_with(path='/status/400/')
            elif url.path.endswith('.pem'):
                return new_url.copy_with(path='/sns/certs/')
            else:
                return new_url.copy_with(path='/status/200/')
        else:
            # return url
            raise ValueError(f'no local endpoint found for "{url}"')
Beispiel #7
0
 def _ensure_path(self, url: httpx.URL) -> httpx.URL:
     """Return *url* unchanged, or a copy with path "/" if its raw path is empty."""
     # The private URI reference is inspected rather than ``url.path``,
     # presumably because the public accessor normalizes an empty path —
     # confirm against the httpx version in use.
     return url if url._uri_reference.path else url.copy_with(path="/")
Beispiel #8
0
class Service:
    """
    The service object to facilitate sending requests to other services.

    :param service_name: The name of the service to send a request to.
    :param partial_path: Base path of the endpoint to send to.
    """

    # Class-level default; each instance stores its own client lazily via the
    # ``client`` property and drops it again in ``close_client``.
    _client = None

    def __init__(self, service_name: str, partial_path: str = None):

        self.service_name = service_name
        # Built once per instance and sent with every request through the
        # client's default "authorization" header (see ``client``).
        self.service_token = jwt_service_encode_handler(
            jwt_service_payload_handler(self))

        # A missing/empty partial path falls back to the root path.
        partial_path = partial_path or "/"
        self.url = URL(
            f"{settings.SERVICE_GLOBAL_SCHEMA}://"
            f"{settings.SERVICE_GLOBAL_HOST_TEMPLATE.format(self.service_name)}"
            f":{settings.SERVICE_GLOBAL_PORT}{partial_path}")
        super().__init__()

    @property
    def client(self) -> AsyncClient:
        """
        The httpx.AsyncClient that will be used to send async requests.

        Created on first access and cached on the instance. Its ``base_url``
        is ``self.url`` as it was at creation time.
        """
        if self._client is None:
            limits = Limits(
                max_connections=settings.SERVICE_CONNECTOR_MAX,
                max_keepalive_connections=settings.
                SERVICE_CONNECTOR_MAX_KEEPALIVE,
            )
            # Per-phase timeouts are optional settings; UNSET leaves a phase
            # at the total timeout.
            timeout = Timeout(
                timeout=settings.SERVICE_TIMEOUT_TOTAL,
                connect=getattr(settings, "SERVICE_TIMEOUT_CONNECT", UNSET),
                read=getattr(settings, "SERVICE_TIMEOUT_READ", UNSET),
                write=getattr(settings, "SERVICE_TIMEOUT_WRITE", UNSET),
                pool=getattr(settings, "SERVICE_TIMEOUT_POOL", UNSET),
            )

            self._client = AsyncClient(
                limits=limits,
                timeout=timeout,
                base_url=self.url,
                headers={
                    "authorization":
                    f"{settings.JWT_SERVICE_AUTH_AUTH_HEADER_PREFIX} {self.service_token}"
                },
            )
        return self._client

    @property
    def host(self) -> str:
        """
        The host portion of the url.
        """
        return self.url.host

    @host.setter
    def host(self, value: str) -> None:
        # Only ``self.url`` is updated; an already-created ``client`` keeps the
        # base_url it was built with.
        self.url = self.url.copy_with(host=value)

    @property
    def port(self) -> int:
        """
        The port portion of the url.
        """
        return self.url.port

    @port.setter
    def port(self, value: int) -> None:
        # Only ``self.url`` is updated; an already-created ``client`` keeps the
        # base_url it was built with.
        self.url = self.url.copy_with(port=value)

    async def close_client(self) -> None:
        """
        Close the async client on shutdown.

        The cached client is detached before closing so that a later access to
        ``client`` builds a fresh client instead of returning a closed one.
        """
        if self._client is not None:
            client, self._client = self._client, None
            await client.aclose()
            # Yield once to the event loop after closing.
            await asyncio.sleep(0)

    def _inject_headers(self, headers: dict):
        """
        Build the outgoing request headers.

        :param headers: caller-supplied headers; not mutated
        :return: ``Headers`` containing the caller's headers (values coerced to
            ``str``) plus date, user and correlation-id headers
        """
        # need to coerce to str
        headers = {k: str(v) for k, v in headers.items()}
        # NOTE(review): "%y" is a two-digit year and "%T" is not supported by
        # every platform's strftime; kept unchanged because receiving services
        # may parse this exact format — confirm before changing.
        headers.update(
            {"date": get_utc_datetime().strftime("%a, %d %b %y %T %z")})

        # inject user information to request headers
        user = context_user()
        headers.update({
            settings.INTERNAL_REQUEST_USER_HEADER.lower():
            to_header_value(user)
        })
        # inject correlation_id
        correlation_id = context_correlation_id()
        headers.update(
            {settings.REQUEST_ID_HEADER_FIELD.lower(): correlation_id})

        return Headers(headers)

    def http_dispatch(
        self,
        method: str,
        endpoint: str,
        *,
        query_params: dict = None,
        payload: dict = None,
        files: dict = None,
        headers: dict = None,
        propagate_error: bool = False,
        include_status_code: bool = False,
        response_timeout: int = UNSET,
        retry_count: int = None,
        **kwargs,
    ):
        """
        Interface for sending requests to other services.

        :param method: method to send request (GET, POST, PATCH, PUT, etc)
        :param endpoint: the path to send request to (eg /api/v1/..)
        :param query_params: query params to attach to url
        :param payload: the data to send on any non GET requests
        :param files: if any files to send with request, must be included here
        :param headers: headers to send along with request
        :param propagate_error: if you want to raise on 400 or greater status codes
        :param include_status_code: if you want this method to return the response with the status code
        :param response_timeout: if you want to increase the timeout for this requests
        :param retry_count: number times you want to retry the request if failed on server errors
        :return: an asyncio task (from ``asyncio.ensure_future``) resolving to
            the result of ``_dispatch_future``
        """

        files = files or {}
        query_params = query_params or {}
        payload = payload or {}
        headers = self._inject_headers(headers or {})

        request = self.client.build_request(
            method,
            endpoint,
            data=payload,
            files=files,
            params=query_params,
            headers=headers,
        )

        return asyncio.ensure_future(
            self._dispatch_future(
                request,
                propagate_error=propagate_error,
                response_timeout=response_timeout,
                include_status_code=include_status_code,
                retry_count=retry_count,
                **kwargs,
            ))

    @staticmethod
    def _status_error_to_exception(e):
        """
        Convert an ``HTTPStatusError`` into an ``APIException`` and log it.

        A JSON-object error body may supply ``description``, ``error_code`` and
        ``message``. Any other body (plain text, or JSON that is not an object)
        is reported as an unknown error whose description is the body itself.

        :param e: the ``HTTPStatusError`` raised for the response
        :return: the ``exceptions.APIException`` to raise
        """
        try:
            body = e.response.json()
        except JSONDecodeError:
            body = e.response.text

        if isinstance(body, dict):
            description = body.get("description", body)
            error_code = body.get("error_code",
                                  GlobalErrorCodes.unknown_error)
            if "message" in body:
                message = body["message"]
            else:
                # Evaluated only when needed: an error_code decoded from JSON
                # is a plain value (e.g. int or str) that has no ``.name``.
                message = getattr(error_code, "name",
                                  GlobalErrorCodes.unknown_error.name)
        else:
            description = body
            error_code = GlobalErrorCodes.unknown_error
            message = error_code.name

        exc = exceptions.APIException(
            description=description,
            error_code=error_code,
            status_code=e.response.status_code,
        )
        exc.message = message

        base_error_message = (
            f"ClientResponseError: {e.request.method} {e.request.url} {e.response.status_code} "
            f"{e}")

        # Server errors are logged at error level; client errors only at info.
        if StatusCode.is_server_error(e.response.status_code):
            error_logger.error(base_error_message)
        else:
            error_logger.info(base_error_message)

        return exc

    async def _dispatch_future(
        self,
        request,
        *,
        propagate_error: bool = False,
        response_timeout: float = None,
        include_status_code: bool = False,
        retry_count: int = None,
        **kwargs,
    ):
        """
        The async method that wraps the actual fetch

        :param request: the request built by ``http_dispatch``
        :param propagate_error: also raise for 4xx responses (5xx responses
            always raise from ``_dispatch_send``)
        :param response_timeout: forwarded to ``client.send`` as ``timeout``
        :param include_status_code: return ``(body, status_code)`` instead of
            just the body
        :param retry_count: retries for GET requests (see ``_dispatch_send``)
        :param kwargs: accepted for interface compatibility; unused
        :return: the JSON-decoded response body, optionally with status code
        :raises exceptions.APIException: for HTTP status, URL, cookie, stream,
            request/transport and DNS errors
        :raises exceptions.ResponseTimeoutError: when the request times out
        """

        try:
            # Shield the send so cancelling the awaiting task does not cancel
            # the underlying request.
            resp = await asyncio.shield(
                self._dispatch_send(
                    request=request,
                    timeout=response_timeout,
                    retry_count=retry_count,
                ))

            if propagate_error:
                resp.raise_for_status()
            response = resp.json()

            if include_status_code:
                return response, resp.status_code
            else:
                return response
        except HTTPStatusError as e:
            raise self._status_error_to_exception(e) from e
        except httpx.TimeoutException as e:
            raise exceptions.ResponseTimeoutError(
                description=str(e),
                error_code=GlobalErrorCodes.service_timeout,
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            ) from e
        except (InvalidURL, NotRedirectResponse) as e:
            raise exceptions.APIException(
                description=str(e),
                error_code=GlobalErrorCodes.invalid_url,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            ) from e
        except CookieConflict as e:
            raise exceptions.APIException(
                description=str(e),
                error_code=GlobalErrorCodes.client_payload_error,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            ) from e
        except StreamError as e:
            raise exceptions.APIException(
                description=str(e),
                error_code=GlobalErrorCodes.stream_error,
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            ) from e
        except RequestError as e:
            # Use the response's status code if the error carries one.
            if hasattr(e, "response"):
                status_code = getattr(
                    e.response,
                    "status_code",
                    status.HTTP_503_SERVICE_UNAVAILABLE,
                )
            else:
                status_code = status.HTTP_503_SERVICE_UNAVAILABLE

            raise exceptions.APIException(
                description=str(e),
                error_code=GlobalErrorCodes.service_unavailable,
                status_code=status_code,
            ) from e
        except HTTPError as e:
            raise exceptions.APIException(
                description=str(e),
                error_code=GlobalErrorCodes.transport_error,
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            ) from e
        except socket.gaierror as e:
            raise exceptions.APIException(
                description=str(e),
                error_code=GlobalErrorCodes.service_unavailable,
                status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
            ) from e

    async def _dispatch_send(
        self,
        request: Request,
        *,
        timeout: float = None,
        retry_count: int = None,
    ):
        """
        Send *request*, retrying GET requests on transport and 5xx errors.

        TODO: need better implementation for retry

        Only GET requests are retried. The retry count defaults to
        ``settings.SERVICE_CONNECTION_DEFAULT_RETRY_COUNT`` and is capped at
        ``settings.SERVICE_CONNECTION_MAX_RETRY_COUNT``. Retries are
        immediate, with no backoff.

        :param request: the request to send
        :param timeout: forwarded to ``client.send``
        :param retry_count: number of extra attempts for GET requests
        :return: the response, which may be a non-5xx error response
        :raises TransportError, HTTPStatusError, ConnectionResetError: from
            the last attempt
        """
        attempts = 1
        if request.method == "GET":
            attempts += (settings.SERVICE_CONNECTION_DEFAULT_RETRY_COUNT
                         if retry_count is None else min(
                             retry_count,
                             int(settings.SERVICE_CONNECTION_MAX_RETRY_COUNT),
                         ))

        for i in range(attempts):
            try:
                response = await self.client.send(request, timeout=timeout)

                # Turn 5xx responses into exceptions so they are retried too.
                if codes.is_server_error(response.status_code):
                    response.raise_for_status()

            except (TransportError, HTTPStatusError,
                    ConnectionResetError) as e:
                error_logger.debug(f"{str(e)} on attempt {i}")
                if i + 1 >= attempts:
                    raise
            else:
                return response