示例#1
0
def test_resolve_request_headers_no_arg(mock_request_header_providers):
    """Resolve request headers without passing an explicit headers argument.

    The ``mock_request_header_providers`` fixture (defined outside this snippet)
    is expected to install the providers that produce the values checked here.
    """
    expected_headers = {
        "one": "override",
        "two": "two-val",
        "three": "three-val",
        "new": "new-val",
    }

    resolved_headers = resolve_request_headers()

    assert resolved_headers == expected_headers
示例#2
0
def test_resolve_request_headers(mock_request_header_providers):
    """Resolve request headers when the caller supplies its own headers dict.

    The expected result shows the caller-supplied ``"two"`` value replacing the
    provider value and the extra ``"arg"`` key being added to the result. The
    provider values come from the ``mock_request_header_providers`` fixture
    (defined outside this snippet).
    """
    caller_headers = {"two": "arg-override", "arg": "arg-val"}
    expected_headers = {
        "one": "override",
        "two": "arg-override",
        "three": "three-val",
        "new": "new-val",
        "arg": "arg-val",
    }

    resolved_headers = resolve_request_headers(caller_headers)

    assert resolved_headers == expected_headers
示例#3
0
def test_run_context_provider_registry_with_installed_plugin():
    """Integration test for request header providers contributed by a plugin.

    Requires the package in tests/resources/mlflow-test-plugin to be installed.
    Despite the function name, the module reloaded and the assertions made here
    concern the request header provider registry.
    """
    # Reload the registry module before inspecting which provider classes are registered.
    reload(mlflow.tracking.request_header.registry)

    from mlflow_test_plugin.request_header_provider import PluginRequestHeaderProvider

    registered_classes = _currently_registered_request_header_provider_classes()
    assert PluginRequestHeaderProvider in registered_classes

    # The test plugin's provider always reports in_context() == False so that it does not
    # pollute request headers in developers' environments. Force it to True for the duration
    # of the check so the plugin's header is actually resolved.
    force_in_context = mock.patch.object(
        PluginRequestHeaderProvider,
        "in_context",
        return_value=True,
    )
    with force_in_context:
        resolved_headers = resolve_request_headers()
        assert resolved_headers["test"] == "header"
示例#4
0
def http_request(
    host_creds,
    endpoint,
    method,
    max_retries=5,
    backoff_factor=2,
    retry_codes=_TRANSIENT_FAILURE_RESPONSE_CODES,
    timeout=120,
    **kwargs,
):
    """
    Makes an HTTP request with the specified method to the specified hostname/endpoint. Transient
    errors such as Rate-limited (429), service unavailable (503) and internal error (500) are
    retried with an exponential back off with backoff_factor * (1, 2, 4, ... seconds).
    The retry handling itself is delegated to ``_get_http_response_with_retries``; the response
    is returned as-is, without parsing its body.

    :param host_creds: A :py:class:`mlflow.rest_utils.MlflowHostCreds` object containing
        hostname and optional authentication.
    :param endpoint: a string for service endpoint, e.g. "/path/to/object".
    :param method: a string indicating the method to use, e.g. "GET", "POST", "PUT".
    :param max_retries: maximum number of retries before throwing an exception.
    :param backoff_factor: a time factor for exponential backoff. e.g. value 5 means the HTTP
      request will be retried with interval 5, 10, 20... seconds. A value of 0 turns off the
      exponential backoff.
    :param retry_codes: a list of HTTP response error codes that qualifies for retry.
    :param timeout: wait for timeout seconds for response from remote server for connect and
      read request.
    :param kwargs: Additional keyword arguments to pass to `requests.Session.request()`

    :return: requests.Response object.
    :raises MlflowException: if the underlying request raises any exception; the original
      exception is attached as the cause.
    """
    hostname = host_creds.host

    # Basic auth takes precedence over a bearer token when both are configured.
    auth_str = None
    if host_creds.username and host_creds.password:
        basic_auth_str = (
            "%s:%s" %
            (host_creds.username, host_creds.password)).encode("utf-8")
        auth_str = "Basic " + base64.standard_b64encode(basic_auth_str).decode(
            "utf-8")
    elif host_creds.token:
        auth_str = "Bearer %s" % host_creds.token

    # Imported inside the function, as in the original code (presumably to avoid a circular
    # import at module load time -- confirm before moving it to the top of the file).
    from mlflow.tracking.request_header.registry import resolve_request_headers

    # Provider-resolved headers override the defaults; Authorization is applied last.
    headers = {**_DEFAULT_HEADERS, **resolve_request_headers()}
    if auth_str:
        headers["Authorization"] = auth_str

    # An explicit server certificate path is passed as ``verify``; otherwise ``verify`` is a
    # boolean derived from the ignore_tls_verification flag.
    if host_creds.server_cert_path is None:
        verify = not host_creds.ignore_tls_verification
    else:
        verify = host_creds.server_cert_path

    if host_creds.client_cert_path is not None:
        kwargs["cert"] = host_creds.client_cert_path

    cleaned_hostname = strip_suffix(hostname, "/")
    url = "%s%s" % (cleaned_hostname, endpoint)
    try:
        return _get_http_response_with_retries(
            method,
            url,
            max_retries,
            backoff_factor,
            retry_codes,
            headers=headers,
            verify=verify,
            timeout=timeout,
            **kwargs,
        )
    except Exception as e:
        # Chain explicitly so the traceback reports the original failure as the direct cause.
        raise MlflowException("API request to %s failed with exception %s" %
                              (url, e)) from e
示例#5
0
def http_request(host_creds,
                 endpoint,
                 retries=3,
                 retry_interval=3,
                 max_rate_limit_interval=60,
                 **kwargs):
    """
    Makes an HTTP request to the specified hostname/endpoint. Ratelimit error code (429) will be
    retried with an exponential back off (1, 2, 4, ... seconds) for at most
    `max_rate_limit_interval` seconds. Responses with a status code outside the range 200-499
    (e.g. internal errors, 500s) will be retried up to `retries` times in total, waiting
    `retry_interval` seconds between successive attempts.

    :param host_creds: A :py:class:`mlflow.rest_utils.MlflowHostCreds` object containing
        hostname and optional authentication.
    :param endpoint: a string for service endpoint, appended to the host with any trailing
        "/" stripped from the host first.
    :param retries: total number of attempts made before giving up.
    :param retry_interval: seconds to wait between successive attempts.
    :param max_rate_limit_interval: upper bound, in seconds, on the time spent sleeping while
        retrying 429 responses within a single attempt.
    :param kwargs: Additional keyword arguments passed to `requests.request()`; the HTTP
        method is supplied this way.
    :return: requests.Response object. A response that is still 429 once the rate-limit budget
        is exhausted is returned to the caller unchanged.
    :raises MlflowException: if no attempt produced a status code in the range 200-499.
    """
    hostname = host_creds.host

    # Basic auth takes precedence over a bearer token when both are configured.
    auth_str = None
    if host_creds.username and host_creds.password:
        basic_auth_str = (
            "%s:%s" %
            (host_creds.username, host_creds.password)).encode("utf-8")
        auth_str = "Basic " + base64.standard_b64encode(basic_auth_str).decode(
            "utf-8")
    elif host_creds.token:
        auth_str = "Bearer %s" % host_creds.token

    # Imported inside the function, as in the original code (presumably to avoid a circular
    # import at module load time -- confirm before moving it to the top of the file).
    from mlflow.tracking.request_header.registry import resolve_request_headers

    # Provider-resolved headers override the defaults; Authorization is applied last.
    headers = {**_DEFAULT_HEADERS, **resolve_request_headers()}
    if auth_str:
        headers["Authorization"] = auth_str

    # An explicit server certificate path is passed as ``verify``; otherwise ``verify`` is a
    # boolean derived from the ignore_tls_verification flag.
    if host_creds.server_cert_path is None:
        verify = not host_creds.ignore_tls_verification
    else:
        verify = host_creds.server_cert_path

    if host_creds.client_cert_path is not None:
        kwargs["cert"] = host_creds.client_cert_path

    def request_with_ratelimit_retries(max_rate_limit_interval, **kwargs):
        """Issue the request, re-issuing it with exponential backoff while it returns 429."""
        response = requests.request(**kwargs)
        time_left = max_rate_limit_interval
        sleep = 1
        while response.status_code == 429 and time_left > 0:
            # The logger uses %-style lazy formatting, so the URL must be a %s argument;
            # a "{path}" placeholder would be emitted literally.
            _logger.warning(
                "API request to %s returned status code 429 (Rate limit exceeded). "
                "Retrying in %d seconds. "
                "Will continue to retry 429s for up to %d seconds.",
                kwargs.get("url"),
                sleep,
                time_left,
            )
            time.sleep(sleep)
            time_left -= sleep
            response = requests.request(**kwargs)
            # Sleep for 1, 2, 4, ... seconds, never exceeding the remaining budget.
            sleep = min(time_left, sleep * 2)
        return response

    cleaned_hostname = strip_suffix(hostname, "/")
    url = "%s%s" % (cleaned_hostname, endpoint)
    for i in range(retries):
        response = request_with_ratelimit_retries(max_rate_limit_interval,
                                                  url=url,
                                                  headers=headers,
                                                  verify=verify,
                                                  **kwargs)
        if 200 <= response.status_code < 500:
            return response

        attempts_left = retries - i - 1
        _logger.error(
            "API request to %s failed with code %s != 200, retrying up to %s more times. "
            "API response body: %s",
            url,
            response.status_code,
            attempts_left,
            response.text,
        )
        # Only wait when another attempt will follow; sleeping before raising is wasted time.
        if attempts_left > 0:
            time.sleep(retry_interval)
    raise MlflowException(
        "API request to %s failed to return code 200 after %s tries" %
        (url, retries))