# Imports assumed by these tests; the `gcs_mock` fixture (a mock of the
# `google.cloud.storage` module) is defined elsewhere in the test module.
import os

import pytest
from unittest import mock
from google.auth.exceptions import DefaultCredentialsError

from mlflow.store.gcs_artifact_repo import GCSArtifactRepository


def test_list_artifacts(gcs_mock):
    dest_path = "/some/path/"
    repo = GCSArtifactRepository("gs://test_bucket" + dest_path, gcs_mock)

    # mocking a single blob returned by the bucket.list_blobs iterator
    # https://googlecloudplatform.github.io/google-cloud-python/latest/storage/buckets.html#google.cloud.storage.bucket.Bucket.list_blobs
    dir_mock = mock.Mock()
    dir_name = "0_subpath"
    dir_mock.configure_mock(prefixes=(dest_path + dir_name + "/",))

    obj_mock = mock.Mock()
    file_name = '1_file'
    obj_mock.configure_mock(name=dest_path + file_name, size=1)

    mock_results = mock.MagicMock()
    mock_results.configure_mock(pages=[dir_mock])
    mock_results.__iter__.return_value = [obj_mock]
    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock_results

    artifacts = repo.list_artifacts()
    assert artifacts[0].path == dir_name
    assert artifacts[0].is_dir is True
    assert artifacts[0].file_size is None
    assert artifacts[1].path == file_name
    assert artifacts[1].is_dir is False
    assert artifacts[1].file_size == obj_mock.size

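# For reference, the real listing call that `mock_results` stands in for has
# roughly this shape (a sketch based on the google-cloud-storage docs linked
# above; `bucket` is assumed to be a `google.cloud.storage.bucket.Bucket`):
#
#     results = bucket.list_blobs(prefix="some/path/", delimiter="/")
#     for page in results.pages:
#         page.prefixes        # tuple of "subdirectory" prefixes on this page
#     # iterating `results` itself yields blob objects exposing `name` and `size`
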
def test_download_artifacts_downloads_expected_content(gcs_mock, tmpdir):
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path, gcs_mock)

    obj_mock_1 = mock.Mock()
    file_path_1 = 'file1'
    obj_mock_1.configure_mock(name=os.path.join(artifact_root_path, file_path_1), size=1)

    obj_mock_2 = mock.Mock()
    file_path_2 = 'file2'
    obj_mock_2.configure_mock(name=os.path.join(artifact_root_path, file_path_2), size=1)

    mock_populated_results = mock.MagicMock()
    mock_populated_results.__iter__.return_value = [obj_mock_1, obj_mock_2]
    mock_empty_results = mock.MagicMock()
    mock_empty_results.__iter__.return_value = []

    def get_mock_listing(prefix, **kwargs):
        """
        Produces a mock listing that only contains content if the specified prefix
        is the artifact root. This allows us to mock `list_artifacts` during the
        `_download_artifacts_into` subroutine without recursively listing the same
        artifacts at every level of the directory traversal.
        """
        # pylint: disable=unused-argument
        prefix = os.path.join("/", prefix)
        if os.path.abspath(prefix) == os.path.abspath(artifact_root_path):
            return mock_populated_results
        else:
            return mock_empty_results

    def mkfile(fname):
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.side_effect = get_mock_listing
    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")
    # Ensure that the `mkfile` side effect copied all of the downloaded artifacts
    # into `tmpdir`
    dir_contents = os.listdir(tmpdir.strpath)
    assert file_path_1 in dir_contents
    assert file_path_2 in dir_contents

def test_get_anonymous_bucket(gcs_mock):
    with pytest.raises(DefaultCredentialsError, match='Test'):
        gcs_mock.Client.return_value\
            .get_bucket.side_effect = \
            mock.Mock(side_effect=DefaultCredentialsError('Test'))
        repo = GCSArtifactRepository("gs://test_bucket", gcs_mock)
        repo._get_bucket("gs://test_bucket")
        anon_call_count = gcs_mock.Client\
            .create_anonymous_client.call_count
        assert anon_call_count == 1
        bucket_call_count = gcs_mock.Client\
            .create_anonymous_client.return_value\
            .get_bucket.call_count
        assert bucket_call_count == 1

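# For context, the credential fallback this test exercises presumably looks like
# the hypothetical sketch below (the real method is
# GCSArtifactRepository._get_bucket; its exact structure is an assumption here):
#
#     def _get_bucket(self, bucket):
#         from google.auth.exceptions import DefaultCredentialsError
#         try:
#             storage_client = self.gcs.Client()
#         except DefaultCredentialsError:
#             storage_client = self.gcs.Client.create_anonymous_client()
#         return storage_client.get_bucket(bucket)
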
def from_artifact_uri(artifact_uri, store):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or s3://my/bucket),
    returns an ArtifactRepository instance capable of logging and downloading artifacts
    on behalf of this URI.

    :param store: An instance of AbstractStore which the artifacts are registered in.
    """
    if artifact_uri.startswith("s3:/"):
        # Import these locally to avoid creating a circular import loop
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("sftp:/"):
        from mlflow.store.sftp_artifact_repo import SFTPArtifactRepository
        return SFTPArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("dbfs:/"):
        from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
        # `DatabricksStore` and `MlflowException` are assumed to be imported at
        # module level.
        if not isinstance(store, DatabricksStore):
            raise MlflowException('`store` must be an instance of DatabricksStore.')
        return DbfsArtifactRepository(artifact_uri, store.http_request_kwargs)
    else:
        from mlflow.store.local_artifact_repo import LocalArtifactRepository
        return LocalArtifactRepository(artifact_uri)

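# Illustrative usage (a sketch, not part of the module): the repository class is
# selected purely by URI scheme, and `store` is only consulted for `dbfs:/` URIs,
# so passing None suffices for the other schemes.
#
#     repo = from_artifact_uri("gs://my-bucket/0/run_id/artifacts", store=None)
#     # -> GCSArtifactRepository instance; repo.log_artifact(...) uploads to GCS
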
def test_download_artifacts(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        fname = fname.replace(tmpdir.strpath, '')
        f = tmpdir.join(fname)
        f.write("hello world!")
        return f.strpath

    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    open(repo._download_artifacts_into("test.txt", tmpdir.strpath)).read()
    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().get_blob\
        .assert_called_with('some/path/test.txt')
    gcs_mock.Client().get_bucket().get_blob()\
        .download_to_filename.assert_called_with(tmpdir + "/test.txt")

def test_log_artifacts(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    subd = tmpdir.mkdir("data").mkdir("subdir")
    subd.join("a.txt").write("A")
    subd.join("b.txt").write("B")
    subd.join("c.txt").write("C")

    # Using os.path.isfile as the side effect verifies that upload_from_filename
    # is invoked with paths to files that actually exist
    gcs_mock.Client.return_value.get_bucket.return_value.blob.return_value\
        .upload_from_filename.side_effect = os.path.isfile
    repo.log_artifacts(subd.strpath)

    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().blob().upload_from_filename\
        .assert_has_calls([
            mock.call('%s/a.txt' % subd.strpath),
            mock.call('%s/b.txt' % subd.strpath),
            mock.call('%s/c.txt' % subd.strpath),
        ], any_order=True)

def test_log_artifact(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    d = tmpdir.mkdir("data")
    f = d.join("test.txt")
    f.write("hello world!")
    fpath = d + '/test.txt'
    fpath = fpath.strpath

    # This will call isfile on the code path being used,
    # thus testing that it's being called with an actual file path
    gcs_mock.Client.return_value.get_bucket.return_value.blob.return_value\
        .upload_from_filename.side_effect = os.path.isfile
    repo.log_artifact(fpath)

    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().blob\
        .assert_called_with('some/path/test.txt')
    gcs_mock.Client().get_bucket().blob().upload_from_filename\
        .assert_called_with(fpath)

def test_list_artifacts(gcs_mock):
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path, gcs_mock)

    # mocked bucket/blob structure
    # gs://test_bucket/experiment_id/run_id/
    #  |- file
    #  |- model
    #     |- model.pb

    # mocking a single blob returned by the bucket.list_blobs iterator
    # https://googlecloudplatform.github.io/google-cloud-python/latest/storage/buckets.html#google.cloud.storage.bucket.Bucket.list_blobs

    # list artifacts at artifact root level
    obj_mock = mock.Mock()
    file_path = 'file'
    obj_mock.configure_mock(name=artifact_root_path + file_path, size=1)

    dir_mock = mock.Mock()
    dir_name = "model"
    dir_mock.configure_mock(prefixes=(artifact_root_path + dir_name + "/",))

    mock_results = mock.MagicMock()
    mock_results.configure_mock(pages=[dir_mock])
    mock_results.__iter__.return_value = [obj_mock]
    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock_results

    artifacts = repo.list_artifacts(path=None)
    assert len(artifacts) == 2
    assert artifacts[0].path == file_path
    assert artifacts[0].is_dir is False
    assert artifacts[0].file_size == obj_mock.size
    assert artifacts[1].path == dir_name
    assert artifacts[1].is_dir is True
    assert artifacts[1].file_size is None

def test_download_artifacts_calls_expected_gcs_client_methods(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    repo.download_artifacts("test.txt")
    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))
    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().get_blob\
        .assert_called_with('some/path/test.txt')
    download_calls = \
        gcs_mock.Client().get_bucket().get_blob().download_to_filename.call_args_list
    assert len(download_calls) == 1
    download_path_arg = download_calls[0][0][0]
    assert "test.txt" in download_path_arg

def test_list_artifacts_with_subdir(gcs_mock):
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path, gcs_mock)

    # mocked bucket/blob structure
    # gs://test_bucket/experiment_id/run_id/
    #  |- model
    #     |- model.pb
    #     |- variables

    # list artifacts at subdirectory level
    dir_name = "model"
    obj_mock = mock.Mock()
    file_path = dir_name + "/" + 'model.pb'
    obj_mock.configure_mock(name=artifact_root_path + file_path, size=1)

    subdir_mock = mock.Mock()
    subdir_name = dir_name + "/" + 'variables'
    subdir_mock.configure_mock(prefixes=(artifact_root_path + subdir_name + "/",))

    mock_results = mock.MagicMock()
    mock_results.configure_mock(pages=[subdir_mock])
    mock_results.__iter__.return_value = [obj_mock]
    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock_results

    artifacts = repo.list_artifacts(path=dir_name)
    assert len(artifacts) == 2
    assert artifacts[0].path == file_path
    assert artifacts[0].is_dir is False
    assert artifacts[0].file_size == obj_mock.size
    assert artifacts[1].path == subdir_name
    assert artifacts[1].is_dir is True
    assert artifacts[1].file_size is None

def from_artifact_uri(artifact_uri):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or s3://my/bucket),
    returns an ArtifactRepository instance capable of logging and downloading artifacts
    on behalf of this URI.
    """
    if artifact_uri.startswith("s3:/"):
        # Import these locally to avoid creating a circular import loop
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    else:
        from mlflow.store.local_artifact_repo import LocalArtifactRepository
        return LocalArtifactRepository(artifact_uri)

def test_list_artifacts_empty(gcs_mock):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock.MagicMock()
    assert repo.list_artifacts() == []