Ejemplo n.º 1
0
def test_download_artifacts_calls_expected_gcs_client_methods(
        gcs_mock, tmpdir):
    """Check the GCS client calls made while downloading one artifact.

    A file-writing side effect is registered through ``mock_method_chain``
    for the ``Client -> bucket -> blob -> download_to_filename`` chain on
    ``gcs_mock``. After the download, the test inspects the calls recorded
    on each link of that chain.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def write_stub_file(destination):
        # Only the base name of the requested destination is honoured; the
        # stub file itself is always written into ``tmpdir``.
        tmpdir.join(os.path.basename(destination)).write("hello world!")

    method_chain = ["Client", "bucket", "blob", "download_to_filename"]
    mock_method_chain(gcs_mock, method_chain, side_effect=write_stub_file)

    repo.download_artifacts("test.txt")

    expected_file = os.path.join(tmpdir.strpath, "test.txt")
    assert os.path.exists(expected_file)

    # Each assertion runs before the mock is called again to reach the next
    # link, because that extra (argument-less) call becomes the "last call".
    client = gcs_mock.Client()
    client.bucket.assert_called_with("test_bucket")
    bucket = client.bucket()
    bucket.blob.assert_called_with("some/path/test.txt")
    blob = bucket.blob()

    recorded_calls = blob.download_to_filename.call_args_list
    assert len(recorded_calls) == 1
    # A recorded call is indexed as (positional args, keyword args).
    positional_args = recorded_calls[0][0]
    assert "test.txt" in positional_args[0]
Ejemplo n.º 2
0
def test_download_artifacts_downloads_expected_content(gcs_mock, tmpdir):
    """Download the artifact root and check both listed files show up.

    ``list_blobs`` on the mocked bucket reports two blobs for the artifact
    root and nothing for any other prefix, while ``download_to_filename``
    on the mocked blob writes a small file into ``tmpdir``.
    """
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path,
                                 gcs_mock)

    def make_blob(file_name):
        # ``name`` has to be set through ``configure_mock``: passing it to
        # the ``Mock`` constructor would name the mock itself instead.
        blob = mock.Mock()
        blob.configure_mock(
            name=os.path.join(artifact_root_path, file_name), size=1)
        return blob

    def make_listing(blobs):
        listing = mock.MagicMock()
        listing.__iter__.return_value = blobs
        return listing

    file_path_1 = 'file1'
    file_path_2 = 'file2'
    mock_populated_results = make_listing(
        [make_blob(file_path_1), make_blob(file_path_2)])
    mock_empty_results = make_listing([])

    def get_mock_listing(prefix, **kwargs):
        """
        Return the populated listing only when ``prefix`` resolves to the
        artifact root, and the empty listing for every other prefix, so
        that directory traversal during the download does not see the same
        artifacts again at each level.
        """
        # pylint: disable=unused-argument
        rooted_prefix = os.path.abspath(os.path.join("/", prefix))
        is_root = rooted_prefix == os.path.abspath(artifact_root_path)
        return mock_populated_results if is_root else mock_empty_results

    def write_stub_file(fname):
        # The stub file always lands in ``tmpdir`` under the base name of
        # the requested destination.
        tmpdir.join(os.path.basename(fname)).write("hello world!")

    bucket_mock = gcs_mock.Client.return_value.get_bucket.return_value
    bucket_mock.list_blobs.side_effect = get_mock_listing
    bucket_mock.get_blob.return_value.download_to_filename.side_effect = \
        write_stub_file

    # Downloading the root directory must succeed ...
    repo.download_artifacts("")
    # ... and the side effect must have placed every listed file in `tmpdir`.
    dir_contents = os.listdir(tmpdir.strpath)
    for expected_name in (file_path_1, file_path_2):
        assert expected_name in dir_contents
Ejemplo n.º 3
0
def test_download_artifacts_calls_expected_gcs_client_methods(
        gcs_mock, tmpdir):
    """Check the GCS client calls made while downloading one artifact.

    ``download_to_filename`` on the mocked
    ``Client -> get_bucket -> get_blob`` chain is given a side effect that
    writes a small file into ``tmpdir``. After the download, the test
    inspects the calls recorded on each link of that chain.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def write_stub_file(destination):
        # Only the base name of the requested destination is honoured; the
        # stub file itself is always written into ``tmpdir``.
        tmpdir.join(os.path.basename(destination)).write("hello world!")

    blob_mock = gcs_mock.Client.return_value.get_bucket.return_value\
        .get_blob.return_value
    blob_mock.download_to_filename.side_effect = write_stub_file

    repo.download_artifacts("test.txt")

    expected_file = os.path.join(tmpdir.strpath, "test.txt")
    assert os.path.exists(expected_file)

    # Each assertion runs before the mock is called again to reach the next
    # link, because that extra (argument-less) call becomes the "last call".
    client = gcs_mock.Client()
    client.get_bucket.assert_called_with('test_bucket')
    bucket = client.get_bucket()
    bucket.get_blob.assert_called_with('some/path/test.txt')
    blob = bucket.get_blob()

    recorded_calls = blob.download_to_filename.call_args_list
    assert len(recorded_calls) == 1
    # A recorded call is indexed as (positional args, keyword args).
    positional_args = recorded_calls[0][0]
    assert "test.txt" in positional_args[0]