def test_list_artifacts(self, databricks_artifact_repo):
    """Check ``list_artifacts`` for a directory path and then for a file path.

    ``_call_endpoint`` is patched to hand back canned ``ListArtifacts.Response``
    messages. For the directory listing the test expects two ``FileInfoEntity``
    items, with the directory entry reporting ``file_size`` as ``None``; for
    the single-file listing it expects an empty result.
    """
    single_file_listing = [FileInfo(path="a.txt", is_dir=False, file_size=0)]
    directory_listing = [
        FileInfo(path="test/a.txt", is_dir=False, file_size=100),
        FileInfo(path="test/dir", is_dir=True, file_size=0),
    ]
    endpoint_target = DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"

    with mock.patch(endpoint_target) as call_endpoint_mock:
        call_endpoint_mock.return_value = ListArtifacts.Response(
            root_uri="", files=directory_listing, next_page_token=None
        )
        artifacts = databricks_artifact_repo.list_artifacts("test/")
        assert isinstance(artifacts, list)
        assert isinstance(artifacts[0], FileInfoEntity)
        assert len(artifacts) == 2

        file_entry = artifacts[0]
        assert file_entry.path == "test/a.txt"
        assert file_entry.is_dir is False
        assert file_entry.file_size == 100

        dir_entry = artifacts[1]
        assert dir_entry.path == "test/dir"
        assert dir_entry.is_dir is True
        assert dir_entry.file_size is None

        # Calling list_artifacts() on a path that's a file should return an empty list
        call_endpoint_mock.return_value = ListArtifacts.Response(
            root_uri="", files=single_file_listing
        )
        artifacts = databricks_artifact_repo.list_artifacts("a.txt")
        assert len(artifacts) == 0
def test_download_artifacts_provides_failure_info(self, databricks_artifact_repo):
    """Check that a failing ``download_artifacts`` call reports each failure.

    ``_download_from_cloud`` is mocked to raise a distinct ``MlflowException``
    on each of its two calls. The exception raised by ``download_artifacts``
    must mention ``MOCK_RUN_ROOT_URI``, both file names and both error texts.
    """
    credentials_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    listing_target = DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    download_target = DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"

    with mock.patch(credentials_target) as read_credentials_mock, mock.patch(
        listing_target
    ) as get_list_mock, mock.patch(download_target) as download_mock:
        get_list_mock.return_value = [
            FileInfo(path="file_1.txt", is_dir=False, file_size=100),
            FileInfo(path="file_2.txt", is_dir=False, file_size=0),
        ]
        # Two separate credential objects with identical contents.
        read_credentials_mock.return_value = [
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            )
            for _ in range(2)
        ]
        # Each successive call to the mocked download raises the next exception.
        download_mock.side_effect = [
            MlflowException("MOCK ERROR 1"),
            MlflowException("MOCK ERROR 2"),
        ]

        with pytest.raises(MlflowException) as exc:
            databricks_artifact_repo.download_artifacts("test_path")

        err_msg = str(exc.value)
        expected_fragments = (
            MOCK_RUN_ROOT_URI,
            "file_1.txt",
            "MOCK ERROR 1",
            "file_2.txt",
            "MOCK ERROR 2",
        )
        for fragment in expected_fragments:
            assert fragment in err_msg
def to_proto(self):
    """Build a ``ProtoFileInfo`` message from this object's fields.

    ``path`` and ``is_dir`` are always copied. ``file_size`` is copied only
    when it is truthy, so both ``None`` and ``0`` leave the proto field unset.

    :return: The populated ``ProtoFileInfo`` message.
    """
    message = ProtoFileInfo()
    message.path = self.path
    message.is_dir = self.is_dir
    if not self.file_size:
        # Nothing to record: a falsy size (None or 0) is not written.
        return message
    message.file_size = self.file_size
    return message
def test_list_artifacts_with_relative_path(self):
    """Check ``list_artifacts`` on a repository built from ``MOCK_SUBDIR_ROOT_URI``.

    The mocked endpoint returns paths prefixed with ``MOCK_SUBDIR``. The test
    expects the returned entries to carry paths without that prefix, and the
    request handed to ``message_to_json`` to carry the prefixed path. A
    single-file listing is then expected to produce an empty result.
    """
    single_file_listing = [
        FileInfo(path=posixpath.join(MOCK_SUBDIR, "a.txt"), is_dir=False, file_size=0)
    ]
    directory_listing = [
        FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/a.txt"), is_dir=False, file_size=100),
        FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/dir"), is_dir=True, file_size=0),
    ]
    root_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    json_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    endpoint_target = DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"

    with mock.patch(root_target) as get_run_artifact_root_mock, mock.patch(
        json_target
    ) as message_mock, mock.patch(endpoint_target) as call_endpoint_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        call_endpoint_mock.return_value = ListArtifacts.Response(
            root_uri="", files=directory_listing, next_page_token=None
        )
        message_mock.return_value = None

        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        artifacts = databricks_artifact_repo.list_artifacts("test")
        assert isinstance(artifacts, list)
        assert isinstance(artifacts[0], FileInfoEntity)
        assert len(artifacts) == 2

        file_entry = artifacts[0]
        assert file_entry.path == "test/a.txt"
        assert file_entry.is_dir is False
        assert file_entry.file_size == 100

        dir_entry = artifacts[1]
        assert dir_entry.path == "test/dir"
        assert dir_entry.is_dir is True
        assert dir_entry.file_size is None

        expected_request = ListArtifacts(
            run_id=MOCK_RUN_ID, path=posixpath.join(MOCK_SUBDIR, "test")
        )
        message_mock.assert_called_with(expected_request)

        # Calling list_artifacts() on a relative path that's a file should return an empty list
        call_endpoint_mock.return_value = ListArtifacts.Response(
            root_uri="", files=single_file_listing, next_page_token=None
        )
        artifacts = databricks_artifact_repo.list_artifacts("a.txt")
        assert len(artifacts) == 0
def test_download_artifacts_awaits_download_completion(self, databricks_artifact_repo, tmpdir):
    """
    Verifies that all asynchronous artifact downloads are joined before `download_artifacts()`
    returns a result to the caller.

    The cloud download is replaced by a function that sleeps before writing its
    destination file; once `download_artifacts()` returns, both destination files
    must exist in `tmpdir` with the written content.
    """
    with mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._get_read_credential_infos"
    ) as read_credential_infos_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + ".list_artifacts"
    ) as get_list_mock, mock.patch(
        DATABRICKS_ARTIFACT_REPOSITORY + "._download_from_cloud"
    ) as download_mock:
        # Two credentials and two listed files.
        read_credential_infos_mock.return_value = [
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
            ArtifactCredentialInfo(
                signed_uri=MOCK_AZURE_SIGNED_URI, type=ArtifactCredentialType.AZURE_SAS_URI
            ),
        ]
        get_list_mock.return_value = [
            FileInfo(path="file_1.txt", is_dir=False, file_size=100),
            FileInfo(path="file_2.txt", is_dir=False, file_size=0),
        ]

        def mock_download_from_cloud(
            cloud_credential_info, dst_local_file_path
        ):  # pylint: disable=unused-argument
            # Sleep in order to simulate a longer-running asynchronous download
            time.sleep(2)
            with open(dst_local_file_path, "w") as f:
                f.write("content")

        download_mock.side_effect = mock_download_from_cloud

        databricks_artifact_repo.download_artifacts("test_path", str(tmpdir))

        # If the call returned before the downloads finished, these files would
        # be missing or empty at this point.
        expected_file1_path = os.path.join(str(tmpdir), "file_1.txt")
        expected_file2_path = os.path.join(str(tmpdir), "file_2.txt")
        for path in [expected_file1_path, expected_file2_path]:
            assert os.path.exists(path)
            with open(path, "r") as f:
                assert f.read() == "content"
def test_list_artifacts_with_relative_path(self):
    """Check ``list_artifacts`` on a repository built from ``MOCK_SUBDIR_ROOT_URI``.

    The mocked endpoint returns paths prefixed with ``MOCK_SUBDIR``. The test
    expects the returned entries to carry paths without that prefix, and the
    request handed to ``message_to_json`` to carry the prefixed path.
    """
    directory_listing = [
        FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/a.txt"), is_dir=False, file_size=100),
        FileInfo(path=posixpath.join(MOCK_SUBDIR, "test/dir"), is_dir=True, file_size=0),
    ]
    root_target = DATABRICKS_ARTIFACT_REPOSITORY + "._get_run_artifact_root"
    json_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    endpoint_target = DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"

    with mock.patch(root_target) as get_run_artifact_root_mock, mock.patch(
        json_target
    ) as message_mock, mock.patch(endpoint_target) as call_endpoint_mock:
        get_run_artifact_root_mock.return_value = MOCK_RUN_ROOT_URI
        call_endpoint_mock.return_value = ListArtifacts.Response(
            root_uri="", files=directory_listing, next_page_token=None
        )
        message_mock.return_value = None

        databricks_artifact_repo = get_artifact_repository(MOCK_SUBDIR_ROOT_URI)
        artifacts = databricks_artifact_repo.list_artifacts("test")
        assert isinstance(artifacts, list)
        assert isinstance(artifacts[0], FileInfoEntity)
        assert len(artifacts) == 2

        file_entry = artifacts[0]
        assert file_entry.path == "test/a.txt"
        assert file_entry.is_dir is False
        assert file_entry.file_size == 100

        dir_entry = artifacts[1]
        assert dir_entry.path == "test/dir"
        assert dir_entry.is_dir is True
        assert dir_entry.file_size is None

        expected_request = ListArtifacts(
            run_id=MOCK_RUN_ID, path=posixpath.join(MOCK_SUBDIR, "test")
        )
        message_mock.assert_called_with(expected_request)
def test_paginated_list_artifacts(self, databricks_artifact_repo):
    """Check that ``list_artifacts`` gathers entries across paginated responses.

    ``_call_endpoint`` yields four responses in turn: three pages of two
    entries each and a final page with no entries. The test expects the six
    paths from the non-empty pages, and expects ``message_to_json`` to have
    been called with the initial request followed by requests for page tokens
    "2", "4" and "6".
    """
    pages = [
        [
            FileInfo(path="a.txt", is_dir=False, file_size=100),
            FileInfo(path="b", is_dir=True, file_size=0),
        ],
        [
            FileInfo(path="c.txt", is_dir=False, file_size=100),
            FileInfo(path="d", is_dir=True, file_size=0),
        ],
        [
            FileInfo(path="e.txt", is_dir=False, file_size=100),
            FileInfo(path="f", is_dir=True, file_size=0),
        ],
        [],
    ]
    # Token attached to each response, in the same order as ``pages``.
    next_tokens = ["2", "4", "6", "8"]

    json_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + ".message_to_json"
    endpoint_target = DATABRICKS_ARTIFACT_REPOSITORY + "._call_endpoint"

    with mock.patch(json_target) as message_mock, mock.patch(
        endpoint_target
    ) as call_endpoint_mock:
        call_endpoint_mock.side_effect = [
            ListArtifacts.Response(root_uri="", files=page, next_page_token=token)
            for page, token in zip(pages, next_tokens)
        ]
        message_mock.return_value = None

        artifacts = databricks_artifact_repo.list_artifacts()
        listed_paths = {entry.path for entry in artifacts}
        assert {"a.txt", "b", "c.txt", "d", "e.txt", "f"} == listed_paths

        calls = [
            mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="")),
            mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token="2")),
            mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token="4")),
            mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token="6")),
        ]
        message_mock.assert_has_calls(calls)
def test_paginated_list_artifacts(self, databricks_artifact_repo):
    """Check that ``list_artifacts`` gathers entries across paginated responses.

    The mocked ``_call_endpoint`` returns three two-entry pages followed by an
    empty page. The six paths from the non-empty pages must be returned, and
    ``message_to_json`` must have been called with the first request and then
    with page tokens '2', '4' and '6'.
    """
    first_page = [
        FileInfo(path='a.txt', is_dir=False, file_size=100),
        FileInfo(path='b', is_dir=True, file_size=0),
    ]
    second_page = [
        FileInfo(path='c.txt', is_dir=False, file_size=100),
        FileInfo(path='d', is_dir=True, file_size=0),
    ]
    third_page = [
        FileInfo(path='e.txt', is_dir=False, file_size=100),
        FileInfo(path='f', is_dir=True, file_size=0),
    ]
    empty_page = []

    json_target = DATABRICKS_ARTIFACT_REPOSITORY_PACKAGE + '.message_to_json'
    endpoint_target = DATABRICKS_ARTIFACT_REPOSITORY + '._call_endpoint'

    with mock.patch(json_target) as message_mock:
        with mock.patch(endpoint_target) as call_endpoint_mock:
            # Responses are consumed one per call, in this order.
            call_endpoint_mock.side_effect = [
                ListArtifacts.Response(root_uri='', files=first_page, next_page_token='2'),
                ListArtifacts.Response(root_uri='', files=second_page, next_page_token='4'),
                ListArtifacts.Response(root_uri='', files=third_page, next_page_token='6'),
                ListArtifacts.Response(root_uri='', files=empty_page, next_page_token='8'),
            ]
            message_mock.return_value = None

            artifacts = databricks_artifact_repo.list_artifacts()
            returned_paths = {info.path for info in artifacts}
            assert {'a.txt', 'b', 'c.txt', 'd', 'e.txt', 'f'} == returned_paths

            calls = [mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path=""))]
            calls.extend(
                mock.call(ListArtifacts(run_id=MOCK_RUN_ID, path="", page_token=token))
                for token in ('2', '4', '6')
            )
            message_mock.assert_has_calls(calls)