def test_download_artifacts_calls_expected_gcs_client_methods_via_bucket_blob(
        gcs_mock, tmpdir):
    """Check that downloading one file goes through ``Client().bucket().blob()``.

    The download is simulated. ``download_to_filename`` at the end of the
    ``Client -> bucket -> blob`` chain is given ``mkfile`` as its side effect,
    and ``mkfile`` writes a small file named after the basename of the requested
    destination into ``tmpdir``.

    NOTE(review): this test used to share its name with a later test in this
    module, ``test_download_artifacts_calls_expected_gcs_client_methods``. That
    test asserts the ``get_bucket``/``get_blob`` chain instead. Because the
    second ``def`` rebinds the name, pytest only ever collected the later test
    and this one never ran. It is renamed here so it is collected again. The two
    tests expect different client calls, so at most one of them can match the
    current ``GCSArtifactRepository`` implementation. Confirm which API is
    current and delete the stale test.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        # Stand-in for a real download: create a file with the destination's
        # basename in tmpdir so the test can check that it "arrived".
        fname = os.path.basename(fname)
        f = tmpdir.join(fname)
        f.write("hello world!")

    # mock_method_chain is defined elsewhere. Presumably it configures
    # gcs_mock.Client().bucket().blob().download_to_filename to call mkfile.
    mock_method_chain(
        gcs_mock,
        [
            "Client",
            "bucket",
            "blob",
            "download_to_filename",
        ],
        side_effect=mkfile,
    )

    repo.download_artifacts("test.txt")
    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))

    gcs_mock.Client().bucket.assert_called_with("test_bucket")
    gcs_mock.Client().bucket().blob.assert_called_with("some/path/test.txt")
    download_calls = gcs_mock.Client().bucket().blob(
    ).download_to_filename.call_args_list
    assert len(download_calls) == 1
    # First positional argument of the single call: the local destination path.
    download_path_arg = download_calls[0][0][0]
    assert "test.txt" in download_path_arg
def test_download_artifacts_downloads_expected_content(gcs_mock, tmpdir):
    """Downloading the artifact root should fetch every blob listed under it.

    Two fake blobs (``file1`` and ``file2``) are reported under the artifact
    root, and every other prefix lists as empty. Each ``download_to_filename``
    call writes a placeholder file into ``tmpdir``, so the final contents of
    ``tmpdir`` show which blobs were downloaded.
    """
    artifact_root_path = "/experiment_id/run_id/"
    repo = GCSArtifactRepository("gs://test_bucket" + artifact_root_path, gcs_mock)

    file_path_1 = 'file1'
    file_path_2 = 'file2'

    def make_blob(relative_path):
        # ``name`` is reserved by the Mock constructor, so it has to be set via
        # configure_mock to become an ordinary attribute of the fake blob.
        blob = mock.Mock()
        blob.configure_mock(name=os.path.join(artifact_root_path, relative_path), size=1)
        return blob

    def make_listing(blobs):
        # A MagicMock whose iteration yields the given blobs.
        listing = mock.MagicMock()
        listing.__iter__.return_value = blobs
        return listing

    mock_populated_results = make_listing([make_blob(file_path_1), make_blob(file_path_2)])
    mock_empty_results = make_listing([])

    def get_mock_listing(prefix, **kwargs):  # pylint: disable=unused-argument
        """
        Return the populated listing only when ``prefix`` resolves to the
        artifact root, and the empty listing otherwise. This keeps the
        directory traversal during the download from seeing the same two
        artifacts again at every level.
        """
        resolved = os.path.abspath(os.path.join("/", prefix))
        if resolved != os.path.abspath(artifact_root_path):
            return mock_empty_results
        return mock_populated_results

    def mkfile(fname):
        # Simulated download: drop a file named after the destination's
        # basename into tmpdir.
        tmpdir.join(os.path.basename(fname)).write("hello world!")

    bucket = gcs_mock.Client.return_value.get_bucket.return_value
    bucket.list_blobs.side_effect = get_mock_listing
    bucket.get_blob.return_value.download_to_filename.side_effect = mkfile

    # Downloading the root directory ("") must succeed.
    repo.download_artifacts("")

    # Every listed artifact should have been "downloaded" into tmpdir by mkfile.
    downloaded = os.listdir(tmpdir.strpath)
    for expected_name in (file_path_1, file_path_2):
        assert expected_name in downloaded
def test_download_artifacts_calls_expected_gcs_client_methods(
        gcs_mock, tmpdir):
    """Check that downloading one file goes through ``get_bucket``/``get_blob``.

    ``download_to_filename`` on the ``Client -> get_bucket -> get_blob`` chain is
    replaced by a stub that writes a file named after the destination's basename
    into ``tmpdir``. The test then checks the bucket name, the blob path and the
    single download call.
    """
    repo = GCSArtifactRepository("gs://test_bucket/some/path", gcs_mock)

    def mkfile(fname):
        # Simulated download: create the destination file locally.
        tmpdir.join(os.path.basename(fname)).write("hello world!")

    configured_blob = gcs_mock.Client.return_value.get_bucket.return_value.get_blob.return_value
    configured_blob.download_to_filename.side_effect = mkfile

    repo.download_artifacts("test.txt")
    assert os.path.exists(os.path.join(tmpdir.strpath, "test.txt"))

    client = gcs_mock.Client()
    client.get_bucket.assert_called_with('test_bucket')
    bucket = client.get_bucket()
    bucket.get_blob.assert_called_with('some/path/test.txt')
    blob = bucket.get_blob()

    download_calls = blob.download_to_filename.call_args_list
    assert len(download_calls) == 1
    # Each recorded call is an (args, kwargs) pair. The first positional
    # argument is the local destination path.
    positional_args, _ = download_calls[0]
    download_path_arg = positional_args[0]
    assert "test.txt" in download_path_arg