Example #1
    def test_log_artifacts_provides_failure_info(self, databricks_artifact_repo, tmpdir):
        src_file1_path = os.path.join(str(tmpdir), "file_1.txt")
        with open(src_file1_path, "w") as f:
            f.write("file1")
        src_file2_path = os.path.join(str(tmpdir), "file_2.txt")
        with open(src_file2_path, "w") as f:
            f.write("file2")

        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
        ) as write_credentials_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
        ) as upload_mock:
            write_credentials_mock.return_value = [
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
            ]
            upload_mock.side_effect = [
                MlflowException("MOCK ERROR 1"),
                MlflowException("MOCK ERROR 2"),
            ]

            with pytest.raises(MlflowException) as exc:
                databricks_artifact_repo.log_artifacts(str(tmpdir), "test_artifacts")

            err_msg = str(exc.value)
            assert MOCK_RUN_ROOT_URI in err_msg
            assert "file_1.txt" in err_msg
            assert "MOCK ERROR 1" in err_msg
            assert "file_2.txt" in err_msg
            assert "MOCK ERROR 2" in err_msg
Example #2
    def test_get_read_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
        """
        Verifies that the `_get_read_credential_infos` method, which is used to resolve read access
        credentials for a collection of artifacts, handles paginated responses properly, issuing
        incremental requests until all pages have been consumed
        """
        assert _MAX_CREDENTIALS_REQUEST_SIZE == 2000, (
            "The maximum request size configured by the client should be consistent with the"
            " Databricks backend. Only update this value if the backend limit has changed."
        )

        # Create 3 chunks of paths, two of which have the maximum request size and one of which
        # is smaller than the maximum chunk size. Aggregate and pass these to
        # `_get_read_credential_infos`, validating that this method decomposes the aggregate
        # list into these expected chunks and makes 3 separate requests
        paths_chunk_1 = ["path1"] * _MAX_CREDENTIALS_REQUEST_SIZE
        paths_chunk_2 = ["path2"] * _MAX_CREDENTIALS_REQUEST_SIZE
        paths_chunk_3 = ["path3"] * 5
        credential_infos_mock_1 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
        ]
        credential_infos_mock_2 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(_MAX_CREDENTIALS_REQUEST_SIZE)
        ]
        credential_infos_mock_3 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(5)
        ]

        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
        ) as message_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
        ) as call_endpoint_mock:
            call_endpoint_mock.side_effect = [
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_1,),
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_2,),
                GetCredentialsForRead.Response(credential_infos=credential_infos_mock_3,),
            ]

            databricks_artifact_repo._get_read_credential_infos(
                MOCK_RUN_ID, paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
            )
            assert call_endpoint_mock.call_count == 3
            assert message_mock.call_count == 3
            message_mock.assert_has_calls(
                [
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                    mock.call(GetCredentialsForRead(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
                ]
            )
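The chunking behavior verified above can be sketched as follows: split the aggregate path list into request-sized batches and issue one credentials request per batch. `request_credentials` is a hypothetical callable standing in for the `_call_endpoint` round trip; the default of 2000 mirrors the limit asserted above.

# Minimal sketch (illustrative only), assuming a per-chunk request callable.
def chunked(seq, size):
    for i in range(0, len(seq), size):
        yield seq[i : i + size]

def get_credential_infos(request_credentials, run_id, paths, max_request_size=2000):
    credential_infos = []
    for chunk in chunked(paths, max_request_size):
        # one backend call per chunk, matching the three expected requests above
        credential_infos.extend(request_credentials(run_id, chunk))
    return credential_infos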
Example #3
    def test_download_artifacts_provides_failure_info(self, databricks_artifact_repo):
        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
        ) as read_credentials_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
        ) as get_list_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
        ) as download_mock:
            get_list_mock.return_value = [
                FileInfo(path="file_1.txt", is_dir=False, file_size=100),
                FileInfo(path="file_2.txt", is_dir=False, file_size=0),
            ]
            read_credentials_mock.return_value = [
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
            ]
            download_mock.side_effect = [
                MlflowException("MOCK ERROR 1"),
                MlflowException("MOCK ERROR 2"),
            ]

            with pytest.raises(MlflowException) as exc:
                databricks_artifact_repo.download_artifacts("test_path")

            err_msg = str(exc.value)
            assert MOCK_RUN_ROOT_URI in err_msg
            assert "file_1.txt" in err_msg
            assert "MOCK ERROR 1" in err_msg
            assert "file_2.txt" in err_msg
            assert "MOCK ERROR 2" in err_msg
Example #4
    def test_get_write_credential_infos_handles_pagination(self, databricks_artifact_repo):
        """
        Verifies that the `_get_write_credential_infos` method, which is used to resolve write
        access credentials for a collection of artifacts, handles paginated responses properly,
        issuing incremental requests until all pages have been consumed
        """
        credential_infos_mock_1 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            ),
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            ),
        ]
        credential_infos_mock_2 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
        ]
        credential_infos_mock_3 = []

        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
        ) as message_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
        ) as call_endpoint_mock:
            get_credentials_for_write_responses = [
                GetCredentialsForWrite.Response(
                    credential_infos=credential_infos_mock_1, next_page_token="2"
                ),
                GetCredentialsForWrite.Response(
                    credential_infos=credential_infos_mock_2, next_page_token="3"
                ),
                GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_3),
            ]
            call_endpoint_mock.side_effect = get_credentials_for_write_responses
            write_credential_infos = databricks_artifact_repo._get_write_credential_infos(
                MOCK_RUN_ID,
                ["testpath"],
            )
            assert write_credential_infos == credential_infos_mock_1 + credential_infos_mock_2
            message_mock.assert_has_calls(
                [
                    mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=["testpath"])),
                    mock.call(
                        GetCredentialsForWrite(
                            run_id=MOCK_RUN_ID, path=["testpath"], page_token="2"
                        )
                    ),
                    mock.call(
                        GetCredentialsForWrite(
                            run_id=MOCK_RUN_ID, path=["testpath"], page_token="3"
                        )
                    ),
                ]
            )
            assert call_endpoint_mock.call_count == 3
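The pagination contract verified above can be sketched as a simple page-token loop: keep requesting pages until the backend stops returning a `next_page_token`. `fetch_page` is a hypothetical callable standing in for the `GetCredentialsForWrite` round trip, assumed to return a `(credential_infos, next_page_token)` pair.

# Minimal sketch (illustrative only), not MLflow's implementation.
def collect_all_pages(fetch_page, run_id, paths):
    credential_infos = []
    page_token = None
    while True:
        page_infos, page_token = fetch_page(run_id, paths, page_token)
        credential_infos.extend(page_infos)
        if not page_token:  # final page: no token returned
            break
    return credential_infos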
Example #5
    def test_get_write_credential_infos_respects_max_request_size(self, databricks_artifact_repo):
        """
        Verifies that the `_get_write_credential_infos` method, which is used to resolve write
        access credentials for a collection of artifacts, batches requests according to a maximum
        request size configured by the backend
        """
        # Create 3 chunks of paths, two of which have the maximum request size and one of which
        # is smaller than the maximum chunk size. Aggregate and pass these to
        # `_get_write_credential_infos`, validating that this method decomposes the aggregate
        # list into these expected chunks and makes 3 separate requests
        paths_chunk_1 = ["path1"] * 2000
        paths_chunk_2 = ["path2"] * 2000
        paths_chunk_3 = ["path3"] * 5
        credential_infos_mock_1 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_1", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(2000)
        ]
        credential_infos_mock_2 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_2", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(2000)
        ]
        credential_infos_mock_3 = [
            ArtifactCredentialInfo(
                signed_uri="http://mock_url_3", type=ArtifactCredentialType.AWS_PRESIGNED_URL
            )
            for _ in range(5)
        ]

        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
        ) as message_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"
        ) as call_endpoint_mock:
            call_endpoint_mock.side_effect = [
                GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_1),
                GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_2),
                GetCredentialsForWrite.Response(credential_infos=credential_infos_mock_3),
            ]

            databricks_artifact_repo._get_write_credential_infos(
                MOCK_RUN_ID,
                paths_chunk_1 + paths_chunk_2 + paths_chunk_3,
            )
            assert call_endpoint_mock.call_count == message_mock.call_count == 3
            message_mock.assert_has_calls(
                [
                    mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_1)),
                    mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_2)),
                    mock.call(GetCredentialsForWrite(run_id=MOCK_RUN_ID, path=paths_chunk_3)),
                ]
            )
Example #6
 def test_log_artifact_gcp_with_headers(
     self, databricks_artifact_repo, test_file, artifact_path, expected_location
 ):
     expected_headers = {header.name: header.value for header in MOCK_HEADERS}
     mock_response = Response()
     mock_response.status_code = 200
     mock_response.close = lambda: None
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
     ) as write_credential_infos_mock, mock.patch(
         "mlflow.utils.rest_utils.cloud_storage_http_request"
     ) as request_mock:
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_GCP_SIGNED_URL,
             type=ArtifactCredentialType.GCP_SIGNED_URL,
             headers=MOCK_HEADERS,
         )
         write_credential_infos_mock.return_value = [mock_credential_info]
         request_mock.return_value = mock_response
         databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
         write_credential_infos_mock.assert_called_with(
             run_id=MOCK_RUN_ID, paths=[expected_location]
         )
         request_mock.assert_called_with(
             "put", MOCK_GCP_SIGNED_URL, data=ANY, headers=expected_headers
         )
Example #7
 def test_databricks_download_file_with_relative_path(
         self, remote_file_path, local_path):
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
     ) as get_run_artifact_root_mock, mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY +
             "._get_read_credentials") as read_credentials_mock, mock.patch(
                 DATABRICKS_ARTIFACT_REPOSITORY +
                 ".list_artifacts") as get_list_mock, mock.patch(
                     DATABRICKS_ARTIFACT_REPOSITORY +
                     "._download_from_cloud") as download_mock:
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_credentials_response_proto = GetCredentialsForRead.Response(
             credentials=mock_credentials)
         read_credentials_mock.return_value = read_credentials_response_proto
         download_mock.return_value = None
         get_list_mock.return_value = []
         databricks_artifact_repo = get_artifact_repository(
             MOCK_SUBDIR_ROOT_URI)
         databricks_artifact_repo.download_artifacts(
             remote_file_path, local_path)
         read_credentials_mock.assert_called_with(
             MOCK_RUN_ID, posixpath.join(MOCK_SUBDIR, remote_file_path))
         download_mock.assert_called_with(mock_credentials, ANY)
Example #8
 def test_log_artifact_aws_with_headers(self, databricks_artifact_repo,
                                        test_file, artifact_path,
                                        expected_location):
     expected_headers = {
         header.name: header.value
         for header in MOCK_HEADERS
     }
     mock_response = Response()
     mock_response.status_code = 200
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                     "._get_write_credentials"
                     ) as write_credentials_mock, mock.patch(
                         "requests.put") as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AWS_SIGNED_URI,
             type=ArtifactCredentialType.AWS_PRESIGNED_URL,
             headers=MOCK_HEADERS,
         )
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         request_mock.return_value = mock_response
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                   expected_location)
         request_mock.assert_called_with(ANY, ANY, headers=expected_headers)
Example #9
 def test_databricks_download_file_with_relative_path(self, remote_file_path, local_path):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
     ) as get_run_artifact_root_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
     ) as read_credential_infos_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
     ) as get_list_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
     ) as download_mock:
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
         )
         read_credential_infos_mock.return_value = [mock_credential_info]
         download_mock.return_value = None
         get_list_mock.return_value = []
         databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         databricks_artifact_repo.download_artifacts(remote_file_path, local_path)
         read_credential_infos_mock.assert_called_with(
             run_id=MOCK_RUN_ID, paths=[posixpath.join(MOCK_SUBDIR, remote_file_path)]
         )
         download_mock.assert_called_with(
             cloud_credential_info=mock_credential_info, dst_local_file_path=ANY,
         )
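The relative-path handling checked above, where a repository rooted at MOCK_SUBDIR_ROOT_URI requests credentials for posixpath.join(MOCK_SUBDIR, remote_file_path), amounts to resolving caller-supplied paths against the repository's subdirectory within the run's artifact root. A sketch of that resolution, with an assumed helper name, follows.

import posixpath

# Minimal sketch (illustrative only): compute the repository's subdirectory
# relative to the run artifact root and prefix caller-relative paths with it.
def resolve_run_relative_path(run_root_uri, repo_root_uri, artifact_path):
    subdir = posixpath.relpath(repo_root_uri, run_root_uri)
    return artifact_path if subdir == "." else posixpath.join(subdir, artifact_path)

# e.g. resolve_run_relative_path(MOCK_RUN_ROOT_URI, MOCK_SUBDIR_ROOT_URI, "a/b.txt")
# -> "subdir/path/a/b.txt"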
Example #10
 def test_log_artifact_with_relative_path(self, test_file, artifact_path,
                                          expected_location):
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                     "._get_run_artifact_root"
                     ) as get_run_artifact_root_mock, mock.patch(
                         DATABRICKS_ARTIFACT_REPOSITORY +
                         "._get_write_credentials"
                     ) as write_credentials_mock, mock.patch(
                         DATABRICKS_ARTIFACT_REPOSITORY +
                         "._upload_to_cloud") as upload_mock:
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         databricks_artifact_repo = get_artifact_repository(
             MOCK_SUBDIR_ROOT_URI)
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         upload_mock.return_value = None
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                   expected_location)
         upload_mock.assert_called_with(write_credentials_response_proto,
                                        test_file.strpath,
                                        expected_location)
Example #11
 def test_log_artifact_gcp(self, databricks_artifact_repo, test_file,
                           artifact_path, expected_location):
     mock_response = Response()
     mock_response.status_code = 200
     mock_response.close = lambda: None
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
     ) as write_credentials_mock, mock.patch(
             "mlflow.utils.rest_utils.cloud_storage_http_request"
     ) as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_GCP_SIGNED_URL,
             type=ArtifactCredentialType.GCP_SIGNED_URL)
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         request_mock.return_value = mock_response
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                   expected_location)
         request_mock.assert_called_with("put",
                                         MOCK_GCP_SIGNED_URL,
                                         ANY,
                                         headers={})
Example #12
    def test_artifact_logging_awaits_upload_completion(
            self, databricks_artifact_repo, tmpdir):
        """
        Verifies that all asynchronous artifact uploads initiated by `log_artifact()` and
        `log_artifacts()` are joined before these methods return a result to the caller
        """
        src_dir = os.path.join(str(tmpdir), "src")
        os.makedirs(src_dir)
        src_file1_path = os.path.join(src_dir, "file_1.txt")
        with open(src_file1_path, "w") as f:
            f.write("file1")
        src_file2_path = os.path.join(src_dir, "file_2.txt")
        with open(src_file2_path, "w") as f:
            f.write("file2")

        dst_dir = os.path.join(str(tmpdir), "dst")
        os.makedirs(dst_dir)

        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                        "._get_write_credentials"
                        ) as write_credentials_mock, mock.patch(
                            DATABRICKS_ARTIFACT_REPOSITORY +
                            "._upload_to_cloud") as upload_mock:
            mock_credentials = ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI,
                type=ArtifactCredentialType.AZURE_SAS_URI)
            write_credentials_response_proto = GetCredentialsForWrite.Response(
                credentials=mock_credentials)
            write_credentials_mock.return_value = write_credentials_response_proto

            def mock_upload_to_cloud(cloud_credentials, local_file_path,
                                     artifact_path):  # pylint: disable=unused-argument
                # Sleep in order to simulate a longer-running asynchronous upload
                time.sleep(2)
                dst_path = os.path.join(dst_dir, artifact_path)
                os.makedirs(os.path.dirname(dst_path), exist_ok=True)
                shutil.copyfile(src=local_file_path, dst=dst_path)

            upload_mock.side_effect = mock_upload_to_cloud

            databricks_artifact_repo.log_artifacts(src_dir, "dir_artifact")

            expected_dst_dir_file1_path = os.path.join(dst_dir, "dir_artifact",
                                                       "file_1.txt")
            expected_dst_dir_file2_path = os.path.join(dst_dir, "dir_artifact",
                                                       "file_2.txt")
            assert os.path.exists(expected_dst_dir_file1_path)
            assert os.path.exists(expected_dst_dir_file2_path)
            with open(expected_dst_dir_file1_path, "r") as f:
                assert f.read() == "file1"
            with open(expected_dst_dir_file2_path, "r") as f:
                assert f.read() == "file2"

            databricks_artifact_repo.log_artifact(src_file1_path)

            expected_dst_file_path = os.path.join(dst_dir, "file_1.txt")
            assert os.path.exists(expected_dst_file_path)
            with open(expected_dst_file_path, "r") as f:
                assert f.read() == "file1"
Example #13
    def test_download_artifacts_awaits_download_completion(self, databricks_artifact_repo, tmpdir):
        """
        Verifies that all asynchronous artifact downloads are joined before `download_artifacts()`
        returns a result to the caller
        """
        with mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
        ) as read_credential_infos_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
        ) as get_list_mock, mock.patch(
            DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
        ) as download_mock:
            read_credential_infos_mock.return_value = [
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
                ArtifactCredentialInfo(
                    signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
                ),
            ]
            get_list_mock.return_value = [
                FileInfo(path="file_1.txt", is_dir=False, file_size=100),
                FileInfo(path="file_2.txt", is_dir=False, file_size=0),
            ]

            def mock_download_from_cloud(
                cloud_credential_info, dst_local_file_path
            ):  # pylint: disable=unused-argument
                # Sleep in order to simulate a longer-running asynchronous download
                time.sleep(2)
                with open(dst_local_file_path, "w") as f:
                    f.write("content")

            download_mock.side_effect = mock_download_from_cloud

            databricks_artifact_repo.download_artifacts("test_path", str(tmpdir))

            expected_file1_path = os.path.join(str(tmpdir), "file_1.txt")
            expected_file2_path = os.path.join(str(tmpdir), "file_2.txt")
            for path in [expected_file1_path, expected_file2_path]:
                assert os.path.exists(path)
                with open(path, "r") as f:
                    assert f.read() == "content"
Example #14
 def test_log_artifact_azure_with_headers(self, databricks_artifact_repo,
                                          test_file, artifact_path,
                                          expected_location):
     mock_azure_headers = {
         "x-ms-encryption-scope": "test-scope",
         "x-ms-tags": "some-tags",
         "x-ms-blob-type": "some-type",
     }
     filtered_azure_headers = {
         "x-ms-encryption-scope": "test-scope",
         "x-ms-tags": "some-tags",
     }
     mock_response = Response()
     mock_response.status_code = 200
     mock_response.close = lambda: None
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credentials"
     ) as write_credentials_mock, mock.patch(
             "mlflow.utils.rest_utils.cloud_storage_http_request"
     ) as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI,
             headers=[
                 ArtifactCredentialInfo.HttpHeader(name=header_name,
                                                   value=header_value) for
                 header_name, header_value in mock_azure_headers.items()
             ],
         )
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         request_mock.return_value = mock_response
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                   expected_location)
         request_mock.assert_called_with(
             "put",
             MOCK_AZURE_SIGNED_URI + "?comp=blocklist",
             ANY,
             headers=filtered_azure_headers,
         )
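The header handling checked above implies that headers relevant only to the block-upload requests (here "x-ms-blob-type") are dropped before the final commit request to MOCK_AZURE_SIGNED_URI + "?comp=blocklist". A sketch of such filtering follows; the exclusion set and helper name are assumptions for illustration.

# Minimal sketch (illustrative only): drop upload-only headers before commit.
_BLOCK_UPLOAD_ONLY_HEADERS = {"x-ms-blob-type"}

def filter_commit_headers(headers):
    return {
        name: value
        for name, value in headers.items()
        if name not in _BLOCK_UPLOAD_ONLY_HEADERS
    }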
Example #15
 def test_log_artifact_gcp_presigned_url_error(self, databricks_artifact_repo, test_file):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
     ) as write_credential_infos_mock, mock.patch(
         "mlflow.utils.rest_utils.cloud_storage_http_request"
     ) as request_mock:
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_GCP_SIGNED_URL, type=ArtifactCredentialType.GCP_SIGNED_URL
         )
         write_credential_infos_mock.return_value = [mock_credential_info]
         request_mock.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)
         write_credential_infos_mock.assert_called_with(run_id=MOCK_RUN_ID, paths=ANY)
Example #16
 def test_log_artifact_azure_blob_client_sas_error(self, databricks_artifact_repo, test_file):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
     ) as write_credential_infos_mock, mock.patch(
         "azure.storage.blob.BlobClient.from_blob_url"
     ) as mock_create_blob_client:
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
         )
         write_credential_infos_mock.return_value = [mock_credential_info]
         mock_create_blob_client.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)
         write_credential_infos_mock.assert_called_with(run_id=MOCK_RUN_ID, paths=ANY)
Example #17
 def test_log_artifact_aws_presigned_url_error(self,
                                               databricks_artifact_repo,
                                               test_file):
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \
             as write_credentials_mock, \
             mock.patch('requests.put') as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AWS_SIGNED_URI,
             type=ArtifactCredentialType.AWS_PRESIGNED_URL)
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         request_mock.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
Example #18
 def test_log_artifact_azure_blob_client_sas_error(self,
                                                   databricks_artifact_repo,
                                                   test_file):
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \
             as write_credentials_mock, \
             mock.patch(
                 'azure.storage.blob.BlobClient.from_blob_url') as mock_create_blob_client:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         mock_create_blob_client.side_effect = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.log_artifact(test_file.strpath)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID, ANY)
Example #19
 def test_databricks_download_file_get_request_fail(self, databricks_artifact_repo, test_file):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
     ) as read_credential_infos_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
     ) as get_list_mock, mock.patch(
         "requests.get"
     ) as request_mock:
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
         )
         read_credential_infos_mock.return_value = [mock_credential_info]
         get_list_mock.return_value = []
         request_mock.return_value = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.download_artifacts(test_file.strpath)
         read_credential_infos_mock.assert_called_with(
             run_id=MOCK_RUN_ID, paths=[test_file.strpath]
         )
Example #20
 def test_databricks_download_file_get_request_fail(
         self, databricks_artifact_repo, test_file):
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \
             as read_credentials_mock, \
             mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \
             mock.patch('requests.get') as request_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI,
             type=ArtifactCredentialType.AZURE_SAS_URI)
         read_credentials_response_proto = GetCredentialsForRead.Response(
             credentials=mock_credentials)
         read_credentials_mock.return_value = read_credentials_response_proto
         get_list_mock.return_value = []
         request_mock.return_value = MlflowException("MOCK ERROR")
         with pytest.raises(MlflowException):
             databricks_artifact_repo.download_artifacts(test_file.strpath)
         read_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                  test_file.strpath)
Example #21
 def test_log_artifact_aws(self, databricks_artifact_repo, test_file,
                           artifact_path, expected_location):
     with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_write_credentials') \
             as write_credentials_mock, \
             mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._aws_upload_file') \
             as aws_upload_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AWS_SIGNED_URI,
             type=ArtifactCredentialType.AWS_PRESIGNED_URL)
         write_credentials_response_proto = GetCredentialsForWrite.Response(
             credentials=mock_credentials)
         write_credentials_mock.return_value = write_credentials_response_proto
         aws_upload_mock.return_value = None
         databricks_artifact_repo.log_artifact(test_file.strpath,
                                               artifact_path)
         write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                   expected_location)
         aws_upload_mock.assert_called_with(mock_credentials,
                                            test_file.strpath)
Example #22
 def test_log_artifact_azure(
     self, databricks_artifact_repo, test_file, artifact_path, expected_location
 ):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
     ) as write_credential_infos_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._azure_upload_file"
     ) as azure_upload_mock:
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
         )
         write_credential_infos_mock.return_value = [mock_credential_info]
         azure_upload_mock.return_value = None
         databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
         write_credential_infos_mock.assert_called_with(
             run_id=MOCK_RUN_ID, paths=[expected_location]
         )
         azure_upload_mock.assert_called_with(
             mock_credential_info, test_file.strpath, expected_location
         )
Example #23
 def test_databricks_download_file(self, databricks_artifact_repo,
                                   remote_file_path, local_path,
                                   cloud_credential_type):
     with mock.patch(
             DATABRICKS_ARTIFACT_REPOSITORY + '._get_read_credentials') \
             as read_credentials_mock, \
             mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '.list_artifacts') as get_list_mock, \
             mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._download_from_cloud') \
             as download_mock:
         mock_credentials = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type)
         read_credentials_response_proto = GetCredentialsForRead.Response(
             credentials=mock_credentials)
         read_credentials_mock.return_value = read_credentials_response_proto
         download_mock.return_value = None
         get_list_mock.return_value = []
         databricks_artifact_repo.download_artifacts(
             remote_file_path, local_path)
         read_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                  remote_file_path)
         download_mock.assert_called_with(mock_credentials, ANY)
Example #24
    def test_log_artifact_azure_with_headers(self, databricks_artifact_repo,
                                             test_file, artifact_path,
                                             expected_location):
        expected_headers = {
            header.name: header.value
            for header in MOCK_HEADERS
        }
        mock_blob_service = mock.MagicMock(autospec=BlobClient)
        with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                        "._get_write_credentials"
                        ) as write_credentials_mock, mock.patch(
                            "azure.storage.blob.BlobClient.from_blob_url"
                        ) as mock_create_blob_client:
            mock_credentials = ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI,
                type=ArtifactCredentialType.AZURE_SAS_URI,
                headers=MOCK_HEADERS,
            )
            write_credentials_response_proto = GetCredentialsForWrite.Response(
                credentials=mock_credentials)
            write_credentials_mock.return_value = write_credentials_response_proto

            mock_create_blob_client.return_value = mock_blob_service
            mock_blob_service.stage_block.side_effect = None
            mock_blob_service.commit_block_list.side_effect = None

            databricks_artifact_repo.log_artifact(test_file.strpath,
                                                  artifact_path)
            write_credentials_mock.assert_called_with(MOCK_RUN_ID,
                                                      expected_location)
            mock_create_blob_client.assert_called_with(
                blob_url=MOCK_AZURE_SIGNED_URI,
                credential=None,
                headers=expected_headers)
            mock_blob_service.stage_block.assert_called_with(
                ANY, ANY, headers=expected_headers)
            mock_blob_service.commit_block_list.assert_called_with(
                ANY, headers=expected_headers)
Example #25
 def test_databricks_download_file(
     self, databricks_artifact_repo, remote_file_path, local_path, cloud_credential_type
 ):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
     ) as read_credential_infos_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
     ) as get_list_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
     ) as download_mock:
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=cloud_credential_type
         )
         read_credential_infos_mock.return_value = [mock_credential_info]
         download_mock.return_value = None
         get_list_mock.return_value = []
         databricks_artifact_repo.download_artifacts(remote_file_path, local_path)
         read_credential_infos_mock.assert_called_with(
             run_id=MOCK_RUN_ID, paths=[remote_file_path]
         )
         download_mock.assert_called_with(
             cloud_credential_info=mock_credential_info, dst_local_file_path=ANY,
         )
Example #26
 def test_log_artifact_with_relative_path(self, test_file, artifact_path, expected_location):
     with mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
     ) as get_run_artifact_root_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._get_write_credential_infos"
     ) as write_credential_infos_mock, mock.patch(
         DATABRICKS_ARTIFACT_REPOSITORY + "._upload_to_cloud"
     ) as upload_mock:
         get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
         databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
         mock_credential_info = ArtifactCredentialInfo(
             signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
         )
         write_credential_infos_mock.return_value = [mock_credential_info]
         upload_mock.return_value = None
         databricks_artifact_repo.log_artifact(test_file.strpath, artifact_path)
         write_credential_infos_mock.assert_called_with(
             run_id=MOCK_RUN_ID, paths=[expected_location]
         )
         upload_mock.assert_called_with(
             cloud_credential_info=mock_credential_info,
             src_file_path=test_file.strpath,
             dst_run_relative_artifact_path=expected_location,
         )
Example #27
from mlflow.exceptions import MlflowException
from mlflow.protos.databricks_artifacts_pb2 import GetCredentialsForWrite, GetCredentialsForRead, \
    ArtifactCredentialType, ArtifactCredentialInfo
from mlflow.protos.service_pb2 import ListArtifacts, FileInfo
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository

DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = 'mlflow.store.artifact.databricks_artifact_repo'
DATABRICKS_ARTIFACT_REPOSITORY = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + \
                                 ".DatabricksArtifactRepository"

MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name='Mock-Name1', value='Mock-Value1'),
    ArtifactCredentialInfo.HttpHeader(name='Mock-Name2', value='Mock-Value2')
]
MOCK_RUN_ROOT_URI = \
    "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY + '._get_run_artifact_root') \
            as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        return get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")
Example #28
    ArtifactCredentialType,
    ArtifactCredentialInfo,
)
from mlflow.protos.service_pb2 import ListArtifacts, FileInfo
from mlflow.store.artifact.artifact_repository_registry import get_artifact_repository
from mlflow.store.artifact.dbfs_artifact_repo import DatabricksArtifactRepository

DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE = "mlflow.store.artifact.databricks_artifact_repo"
DATABRICKS_ARTIFACT_REPOSITORY = (DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE +
                                  ".DatabricksArtifactRepository")

MOCK_AZURE_SIGNED_URI = "http://this_is_a_mock_sas_for_azure"
MOCK_AWS_SIGNED_URI = "http://this_is_a_mock_presigned_uri_for_aws?"
MOCK_RUN_ID = "MOCK-RUN-ID"
MOCK_HEADERS = [
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name1", value="Mock-Value1"),
    ArtifactCredentialInfo.HttpHeader(name="Mock-Name2", value="Mock-Value2"),
]
MOCK_RUN_ROOT_URI = "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts"
MOCK_SUBDIR = "subdir/path"
MOCK_SUBDIR_ROOT_URI = posixpath.join(MOCK_RUN_ROOT_URI, MOCK_SUBDIR)


@pytest.fixture()
def databricks_artifact_repo():
    with mock.patch(DATABRICKS_ARTIFACT_REPOSITORY +
                    "._get_run_artifact_root") as get_run_artifact_root_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        return get_artifact_repository(
            "dbfs:/databricks/mlflow-tracking/MOCK-EXP/MOCK-RUN-ID/artifacts")