def _copy_artifact(local_file, artifact_uri, artifact_path=None):
    """Upload ``local_file`` to its DBFS location under ``artifact_uri`` in the registry workspace.

    The target endpoint is built by ``_get_dbfs_endpoint`` from ``artifact_uri``
    and the file's basename, nested under ``artifact_path`` when one is given.
    The request is sent as a ``POST`` with the ``'registry'`` Databricks host
    credentials and without following redirects.

    :param local_file: Path of the local file to upload.
    :param artifact_uri: Artifact root URI passed to ``_get_dbfs_endpoint``.
    :param artifact_path: Optional relative (POSIX-style) directory under
        ``artifact_uri``; the file's basename is joined onto it.
    :raises MlflowException: for any upload failure except a
        "File already exists" error, which is reported and skipped.
    """
    basename = os.path.basename(local_file)
    # Remote paths are POSIX-style regardless of the local OS, hence posixpath.
    relative_path = posixpath.join(artifact_path, basename) if artifact_path else basename
    http_endpoint = _get_dbfs_endpoint(artifact_uri, relative_path)

    host_creds = get_databricks_host_creds('registry')
    print("Copying file to " + http_endpoint + " in registry workspace")
    try:
        if os.stat(local_file).st_size == 0:
            # The API frontend doesn't like it when we post empty files to it using
            # `requests.request`, potentially due to the bug described in
            # https://github.com/requests/requests/issues/4215
            http_request_safe(host_creds,
                              endpoint=http_endpoint,
                              method='POST',
                              data="",
                              allow_redirects=False)
        else:
            with open(local_file, 'rb') as f:
                http_request_safe(host_creds,
                                  endpoint=http_endpoint,
                                  method='POST',
                                  data=f,
                                  allow_redirects=False)
    except MlflowException as e:
        # Note: instead of catching the error here, we could check for the existence of
        # the file before trying the copy.
        if "File already exists" not in e.message:
            # Propagate every other failure with its original traceback intact.
            raise
        print("File already exists - continuing to the next file.")
# --- Example no. 2 (0) ---
def test_http_request_wrapper(request):
    """Check that ``http_request_safe`` forwards calls and surfaces HTTP errors.

    ``request`` is the mocked low-level request callable. It is presumably
    injected by a ``mock.patch`` decorator outside this view; confirm against
    the surrounding file.

    The test asserts three things:

    * A 200 response results in a call with the expected URL, ``verify`` flag
      and headers, whether the body is JSON or not.
    * A 400 response with an empty body raises ``MlflowException`` matching
      "Response body".
    * A 400 response with a JSON error payload raises ``RestException`` that
      carries the error code and message.
    """
    creds = MlflowHostCreds("http://my-host", ignore_tls_verification=True)
    fake_response = mock.MagicMock()
    # The same mock object is mutated below, so wiring it up once is enough.
    request.return_value = fake_response

    # Successful responses: both JSON and non-JSON bodies are accepted.
    fake_response.status_code = 200
    for body in ("{}", "non json"):
        fake_response.text = body
        http_request_safe(creds, "/my/endpoint")
        request.assert_called_with(
            url="http://my-host/my/endpoint",
            verify=False,
            headers=_DEFAULT_HEADERS,
        )

    # Error response with no body.
    fake_response.status_code = 400
    fake_response.text = ""
    with pytest.raises(MlflowException, match="Response body"):
        http_request_safe(creds, "/my/endpoint")

    # Error response with a structured JSON error payload.
    fake_response.text = (
        '{"error_code": "RESOURCE_DOES_NOT_EXIST", "message": "Node type not supported"}'
    )
    with pytest.raises(
        RestException, match="RESOURCE_DOES_NOT_EXIST: Node type not supported"
    ):
        http_request_safe(creds, "/my/endpoint")
# --- Example no. 3 (0) ---
 def _databricks_api_request(self, endpoint, method, **kwargs):
     """Issue an HTTP ``method`` request to ``endpoint`` on this profile's Databricks host.

     Host credentials are resolved from ``self.databricks_profile`` via
     ``databricks_utils.get_databricks_host_creds``. Any extra keyword arguments
     are forwarded unchanged to ``rest_utils.http_request_safe``, and its result
     is returned.
     """
     profile = self.databricks_profile
     creds = databricks_utils.get_databricks_host_creds(profile)
     return rest_utils.http_request_safe(
         host_creds=creds, endpoint=endpoint, method=method, **kwargs
     )
# --- Example no. 4 (0) ---
    def _call_endpoint(self, api, json_body):
        """Call the REST endpoint registered for ``api`` and parse the reply into its proto.

        ``_METHOD_TO_INFO[api]`` supplies the ``(endpoint, method)`` pair.

        ``json_body`` is a JSON string. When it is truthy, it is decoded into a
        dict first. For ``GET`` requests the payload is sent as query
        parameters; otherwise it is sent as a JSON body.

        The response text is decoded as JSON and loaded into a fresh
        ``api.Response()`` message via ``parse_dict``, which is returned.
        """
        endpoint, method = _METHOD_TO_INFO[api]
        response_proto = api.Response()
        # Decode the JSON string into a dictionary so requests can serialize it;
        # falsy bodies (e.g. None / "") are passed through untouched.
        payload = json.loads(json_body) if json_body else json_body
        # GET carries its payload in the query string; other verbs use a JSON body.
        payload_kwarg = 'params' if method == 'GET' else 'json'
        response = http_request_safe(
            host_creds=self.get_host_creds(),
            endpoint=endpoint,
            method=method,
            **{payload_kwarg: payload}
        )
        parse_dict(js_dict=json.loads(response.text), message=response_proto)
        return response_proto
# --- Example no. 5 (0) ---
 def _databricks_api_request(self, endpoint, **kwargs):
     """Send a request to ``endpoint`` using the credentials from ``self.get_host_creds()``.

     All extra keyword arguments are forwarded unchanged to ``http_request_safe``,
     and its result is returned.
     """
     creds = self.get_host_creds()
     return http_request_safe(endpoint=endpoint, host_creds=creds, **kwargs)