Example #1
def test_list_artifacts(gcs_mock):
    dest_path = "/some/path/"
    repo = GCSArtifactRepository("gs://test_bucket" + dest_path, gcs_mock)

    # mocking a single blob returned by bucket.list_blobs iterator
    # https://googlecloudplatform.github.io/google-cloud-python/latest/storage/buckets.html#google.cloud.storage.bucket.Bucket.list_blobs
    dir_mock = mock.Mock()
    dir_name = "0_subpath"
    dir_mock.configure_mock(prefixes=(dest_path + dir_name + "/",))

    obj_mock = mock.Mock()
    file_name = '1_file'
    obj_mock.configure_mock(name=dest_path + file_name, size=1)

    mock_results = mock.MagicMock()
    mock_results.configure_mock(pages=[dir_mock])
    mock_results.__iter__.return_value = [obj_mock]

    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock_results

    artifacts = repo.list_artifacts()
    assert artifacts[0].path == dir_name
    assert artifacts[0].is_dir is True
    assert artifacts[0].file_size is None
    assert artifacts[1].path == file_name
    assert artifacts[1].is_dir is False
    assert artifacts[1].file_size == obj_mock.size
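
Each test in these examples takes a gcs_mock pytest fixture, which is not part of this listing. A minimal stand-in (an assumption, not the project's actual conftest) would simply substitute a MagicMock for the google.cloud.storage module; the tests also rely on mock, pytest, os, GCSArtifactRepository, and google.auth.exceptions.DefaultCredentialsError being imported:

import os

import mock
import pytest
from google.auth.exceptions import DefaultCredentialsError

from mlflow.store.gcs_artifact_repo import GCSArtifactRepository


@pytest.fixture
def gcs_mock():
    # Hypothetical stand-in for the google.cloud.storage module; the repository
    # under test only reaches it through Client(), get_bucket(), get_blob(),
    # blob() and list_blobs(), all of which a MagicMock absorbs.
    return mock.MagicMock()
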
Example #2
def test_download_artifacts_downloads_expected_content(gcs_mock, tmpdir):
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path,
                                 gcs_mock)

    obj_mock_1 = mock.Mock()
    file_path_1 = 'file1'
    obj_mock_1.configure_mock(name=os.path.join(artifact_root_path,
                                                file_path_1),
                              size=1)
    obj_mock_2 = mock.Mock()
    file_path_2 = 'file2'
    obj_mock_2.configure_mock(name=os.path.join(artifact_root_path,
                                                file_path_2),
                              size=1)
    mock_populated_results = mock.MagicMock()
    mock_populated_results.__iter__.return_value = [obj_mock_1, obj_mock_2]

    mock_empty_results = mock.MagicMock()
    mock_empty_results.__iter__.return_value = []

    def get_mock_listing(prefix, **kwargs):
        """
        Produces a mock listing that only contains content if the
        specified prefix is the artifact root. This allows us to mock
        `list_artifacts` during the `_download_artifacts_into` subroutine
        without recursively listing the same artifacts at every level of the
        directory traversal.
        """
        # pylint: disable=unused-argument
        prefix = os.path.join("/", prefix)
        if os.path.abspath(prefix) == os.path.abspath(artifact_root_path):
            return mock_populated_results
        else:
            return mock_empty_results

    def mkfile(fname):
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.side_effect = get_mock_listing

    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    # Ensure that the root directory can be downloaded successfully
    repo.download_artifacts("")
    # Ensure that the `mkfile` side effect copied all of the download artifacts into `tmpdir`
    dir_contents = os.listdir(tmpdir.strpath)
    assert file_path_1 in dir_contents
    assert file_path_2 in dir_contents
Example #3
def test_get_anonymous_bucket(gcs_mock):
    # Simulate missing default credentials when the authenticated client is
    # constructed, so that the repository falls back to an anonymous client
    gcs_mock.Client.side_effect = DefaultCredentialsError('Test')
    repo = GCSArtifactRepository("gs://test_bucket", gcs_mock)
    repo._get_bucket("gs://test_bucket")
    anon_call_count = gcs_mock.Client\
        .create_anonymous_client.call_count
    assert anon_call_count == 1
    bucket_call_count = gcs_mock.Client\
        .create_anonymous_client.return_value\
        .get_bucket.call_count
    assert bucket_call_count == 1
Example #4
def from_artifact_uri(artifact_uri, store):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or s3://my/bucket),
    returns an ArtifactRepository instance capable of logging and downloading artifacts
    on behalf of this URI.
    :param store: An instance of AbstractStore which the artifacts are registered in.
    """
    if artifact_uri.startswith("s3:/"):
        # Import these locally to avoid creating a circular import loop
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("sftp:/"):
        from mlflow.store.sftp_artifact_repo import SFTPArtifactRepository
        return SFTPArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("dbfs:/"):
        from mlflow.store.dbfs_artifact_repo import DbfsArtifactRepository
        if not isinstance(store, DatabricksStore):
            raise MlflowException(
                '`store` must be an instance of DatabricksStore.')
        return DbfsArtifactRepository(artifact_uri,
                                      store.http_request_kwargs)
    else:
        from mlflow.store.local_artifact_repo import LocalArtifactRepository
        return LocalArtifactRepository(artifact_uri)
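
A quick, hypothetical illustration of the dispatch above (URI and file path are made up; the store argument is only consulted for dbfs:/ URIs, so None suffices here, and actually running this would require GCS credentials):

repo = from_artifact_uri("gs://my-bucket/0/abc123/artifacts", store=None)
# repo is a GCSArtifactRepository; the shared ArtifactRepository interface
# offers log_artifact / log_artifacts / list_artifacts / download_artifacts
repo.log_artifact("/tmp/model.pkl")
print(repo.list_artifacts())
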
Example #5
def test_download_artifacts(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        fname = fname.replace(tmpdir.strpath, '')
        f = tmpdir.join(fname)
        f.write("hello world!")
        return f.strpath

    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    open(repo._download_artifacts_into("test.txt", tmpdir.strpath)).read()
    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().get_blob\
        .assert_called_with('some/path/test.txt')
    gcs_mock.Client().get_bucket().get_blob()\
        .download_to_filename.assert_called_with(tmpdir + "/test.txt")
Example #6
def test_log_artifacts(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    subd = tmpdir.mkdir("data").mkdir("subdir")
    subd.join("a.txt").write("A")
    subd.join("b.txt").write("B")
    subd.join("c.txt").write("C")

    # Using os.path.isfile as the side effect verifies that the mocked upload
    # is invoked with paths to files that actually exist on disk
    gcs_mock.Client.return_value.get_bucket.return_value.blob.return_value\
        .upload_from_filename.side_effect = os.path.isfile
    repo.log_artifacts(subd.strpath)

    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().blob().upload_from_filename\
        .assert_has_calls([
            mock.call('%s/a.txt' % subd.strpath),
            mock.call('%s/b.txt' % subd.strpath),
            mock.call('%s/c.txt' % subd.strpath),
        ], any_order=True)
Example #7
def test_log_artifact(gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    d = tmpdir.mkdir("data")
    f = d.join("test.txt")
    f.write("hello world!")
    fpath = f.strpath

    # The os.path.isfile side effect runs on the real code path being used,
    # thus verifying that the upload is invoked with an actual file path
    gcs_mock.Client.return_value.get_bucket.return_value.blob.return_value\
        .upload_from_filename.side_effect = os.path.isfile
    repo.log_artifact(fpath)

    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().blob\
        .assert_called_with('some/path/test.txt')
    gcs_mock.Client().get_bucket().blob().upload_from_filename\
        .assert_called_with(fpath)
Example #8
def test_list_artifacts(gcs_mock):
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path,
                                 gcs_mock)

    # mocked bucket/blob structure
    # gs://test_bucket/experiment_id/run_id/
    #  |- file
    #  |- model
    #     |- model.pb

    # mocking a single blob returned by bucket.list_blobs iterator
    # https://googlecloudplatform.github.io/google-cloud-python/latest/storage/buckets.html#google.cloud.storage.bucket.Bucket.list_blobs

    # list artifacts at artifact root level
    obj_mock = mock.Mock()
    file_path = 'file'
    obj_mock.configure_mock(name=artifact_root_path + file_path, size=1)

    dir_mock = mock.Mock()
    dir_name = "model"
    dir_mock.configure_mock(prefixes=(artifact_root_path + dir_name + "/", ))

    mock_results = mock.MagicMock()
    mock_results.configure_mock(pages=[dir_mock])
    mock_results.__iter__.return_value = [obj_mock]

    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock_results

    artifacts = repo.list_artifacts(path=None)

    assert len(artifacts) == 2
    assert artifacts[0].path == file_path
    assert artifacts[0].is_dir is False
    assert artifacts[0].file_size == obj_mock.size
    assert artifacts[1].path == dir_name
    assert artifacts[1].is_dir is True
    assert artifacts[1].file_size is None
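
The mock wiring above (file blobs via the iterator, directory names via pages[*].prefixes) mirrors how a delimiter-based listing behaves with the real google-cloud-storage client. A rough sketch of that pattern, with illustrative bucket and prefix names:

from google.cloud import storage

client = storage.Client()
bucket = client.get_bucket("test_bucket")
results = bucket.list_blobs(prefix="experiment_id/run_id/", delimiter="/")
for page in results.pages:
    print(page.prefixes)             # "subdirectory" prefixes, e.g. ('experiment_id/run_id/model/',)
    for blob in page:
        print(blob.name, blob.size)  # file blobs directly under the prefix
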
Example #9
def test_download_artifacts_calls_expected_gcs_client_methods(
        gcs_mock, tmpdir):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value\
        .download_to_filename.side_effect = mkfile

    repo.download_artifacts("test.txt")
    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))
    gcs_mock.Client().get_bucket.assert_called_with('test_bucket')
    gcs_mock.Client().get_bucket().get_blob\
        .assert_called_with('some/path/test.txt')
    download_calls = \
        gcs_mock.Client().get_bucket().get_blob().download_to_filename.call_args_list
    assert len(download_calls) == 1
    download_path_arg = download_calls[0][0][0]
    assert "test.txt" in download_path_arg
Example #10
def test_list_artifacts_with_subdir(gcs_mock):
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path,
                                 gcs_mock)

    # mocked bucket/blob structure
    # gs://test_bucket/experiment_id/run_id/
    #  |- model
    #     |- model.pb
    #     |- variables

    # list artifacts at sub directory level
    dir_name = "model"
    obj_mock = mock.Mock()
    file_path = dir_name + "/" + 'model.pb'
    obj_mock.configure_mock(name=artifact_root_path + file_path, size=1)

    subdir_mock = mock.Mock()
    subdir_name = dir_name + "/" + 'variables'
    subdir_mock.configure_mock(prefixes=(artifact_root_path + subdir_name +
                                         "/", ))

    mock_results = mock.MagicMock()
    mock_results.configure_mock(pages=[subdir_mock])
    mock_results.__iter__.return_value = [obj_mock]

    gcs_mock.Client.return_value.get_bucket.return_value\
        .list_blobs.return_value = mock_results

    artifacts = repo.list_artifacts(path=dir_name)
    assert len(artifacts) == 2
    assert artifacts[0].path == file_path
    assert artifacts[0].is_dir is False
    assert artifacts[0].file_size == obj_mock.size
    assert artifacts[1].path == subdir_name
    assert artifacts[1].is_dir is True
    assert artifacts[1].file_size is None
Example #11
def from_artifact_uri(artifact_uri):
    """
    Given an artifact URI for an Experiment Run (e.g., /local/file/path or s3://my/bucket),
    returns an ArtifactRepository instance capable of logging and downloading artifacts
    on behalf of this URI.
    """
    if artifact_uri.startswith("s3:/"):
        # Import these locally to avoid creating a circular import loop
        from mlflow.store.s3_artifact_repo import S3ArtifactRepository
        return S3ArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("gs:/"):
        from mlflow.store.gcs_artifact_repo import GCSArtifactRepository
        return GCSArtifactRepository(artifact_uri)
    elif artifact_uri.startswith("wasbs:/"):
        from mlflow.store.azure_blob_artifact_repo import AzureBlobArtifactRepository
        return AzureBlobArtifactRepository(artifact_uri)
    else:
        from mlflow.store.local_artifact_repo import LocalArtifactRepository
        return LocalArtifactRepository(artifact_uri)
Example #12
def test_list_artifacts_empty(gcs_mock):
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)
    gcs_mock.Client.return_value.get_bucket.return_value \
        .list_blobs.return_value = mock.MagicMock()
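    # A bare MagicMock iterates as empty and exposes no prefix pages, so the
    # repository should report an empty listing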
    assert repo.list_artifacts() == []