def test_gcs_client_warpper():
    """Exercise GCSClient.download and GCSClient.upload against a mocked
    storage ``Client`` and check the bucket -> blob calls they make."""
    bucket = "fake_bucket"
    key = "path/to/object"
    local_file = "path/to/local/file"

    fake_storage_client = MagicMock()
    with patch(
        "datasetinsights.io.gcs.Client",
        MagicMock(return_value=fake_storage_client),
    ):
        client = GCSClient()

        # Wire the mocked storage client: get_bucket -> bucket, bucket.blob -> blob.
        fake_bucket = MagicMock()
        fake_blob = MagicMock()
        fake_storage_client.get_bucket = MagicMock(return_value=fake_bucket)
        fake_bucket.blob = MagicMock(return_value=fake_blob)

        # Download path: bucket lookup, blob lookup, then download to local file.
        fake_blob.download_to_filename = MagicMock()
        client.download(bucket, key, local_file)
        fake_storage_client.get_bucket.assert_called_with(bucket)
        fake_bucket.blob.assert_called_with(key)
        fake_blob.download_to_filename.assert_called_with(local_file)

        # Upload path: the same blob receives the local file.
        fake_blob.upload_from_filename = MagicMock()
        client.upload(local_file, bucket, key)
        fake_blob.upload_from_filename.assert_called_with(local_file)
def test_gcs_client_upload_file_bucket_key(mock_isdir, mock_upload_file):
    """Uploading a single local file with explicit bucket/key arguments should
    call the (mocked) file-upload helper with the resolved bucket.

    ``mock_isdir`` and ``mock_upload_file`` are injected mocks (presumably from
    ``patch`` decorators not shown here — confirm). ``isdir`` is forced to
    return False so the local path is treated as a file.
    """
    fake_storage_client = MagicMock()
    file_path = local_path + file_name
    mock_isdir.return_value = False

    with patch(
        "datasetinsights.io.gcs.Client",
        MagicMock(return_value=fake_storage_client),
    ):
        client = GCSClient()
        client.upload(local_path=file_path, bucket=bucket_name, key=base_key)

        # A MagicMock returns the same return_value object on every call.
        expected_bucket = fake_storage_client.get_bucket()
        mock_upload_file.assert_called_with(
            bucket=expected_bucket,
            key=base_key,
            local_path=file_path,
        )
def test_gcs_client_upload_folder_url(mock_isdir, mock_upload_folder):
    """Uploading a local directory given a ``url`` should call the (mocked)
    folder-upload helper with the key parsed from that URL and the glob pattern.

    ``mock_isdir`` and ``mock_upload_folder`` are injected mocks (presumably
    from ``patch`` decorators not shown here — confirm). ``isdir`` is forced to
    return True so the local path is treated as a directory.
    """
    fake_storage_client = MagicMock()
    mock_isdir.return_value = True
    target_url = base_url

    with patch(
        "datasetinsights.io.gcs.Client",
        MagicMock(return_value=fake_storage_client),
    ):
        client = GCSClient()
        client.upload(local_path=local_path, url=target_url, pattern="*")

        # A MagicMock returns the same return_value object on every call.
        expected_bucket = fake_storage_client.get_bucket()
        mock_upload_folder.assert_called_with(
            bucket=expected_bucket,
            key=base_key,
            local_path=local_path,
            pattern="*",
        )
class GCSEstimatorWriter:
    """Writes (saves) estimator checkpoints on GCS.

    Each checkpoint is first written locally by a ``LocalEstimatorWriter``
    into a private temporary directory, then uploaded to GCS under
    ``cloud_path``. The temporary directory lives as long as this writer.

    Args:
        cloud_path (str): GCS cloud path (e.g. gs://bucket/path/to/directory)
        prefix (str): filename prefix of the checkpoint files
        suffix (str): filename suffix of the checkpoint files
    """

    def __init__(self, cloud_path, prefix, *, suffix=DEFAULT_SUFFIX):
        # Keep the TemporaryDirectory object alive. Keeping only its ``.name``
        # lets the object be garbage collected right away (immediately on
        # CPython). Its finalizer then deletes the directory, and whatever is
        # later recreated at that path is never cleaned up.
        self._tempdir_handle = tempfile.TemporaryDirectory()
        # Write into a not-yet-existing subdirectory so LocalEstimatorWriter's
        # ``create_dir=True`` still creates it as before. Everything is removed
        # together with the handle.
        self._tempdir = os.path.join(self._tempdir_handle.name, "checkpoints")
        self._client = GCSClient()
        self._bucket, self._gcs_path = gcs_bucket_and_path(cloud_path)
        self._writer = LocalEstimatorWriter(
            self._tempdir, prefix, create_dir=True, suffix=suffix
        )

    def save(self, estimator, epoch=None):
        """Save estimator to checkpoint files on GCS.

        Args:
            estimator (datasetinsights.estimators.Estimator):
                datasetinsights estimator object.
            epoch (int): the current epoch number. Default: None

        Returns:
            Full GCS cloud path to the saved checkpoint file.
        """
        # GCS object keys always use "/" separators. os.path.join would
        # produce "\" on Windows; posixpath.join matches it on POSIX.
        import posixpath

        path = self._writer.save(estimator, epoch)
        filename = os.path.basename(path)
        object_key = posixpath.join(self._gcs_path, filename)
        full_cloud_path = f"gs://{self._bucket}/{object_key}"

        logger.debug(f"Copying estimator from {path} to {full_cloud_path}")
        self._client.upload(path, self._bucket, object_key)

        return full_cloud_path