示例#1
0
def test_gcs_client_warpper():
    """Exercise GCSClient.download and GCSClient.upload against a mocked GCS.

    The ``Client`` name inside ``datasetinsights.io.gcs`` is patched so that
    it hands back a mock. The test then asserts that a download looks up the
    bucket by name, requests the blob by object key and calls
    ``download_to_filename`` with the local file, and that an upload calls
    ``upload_from_filename`` with the local file.
    """
    bucket_name, object_key, localfile = (
        "fake_bucket",
        "path/to/object",
        "path/to/local/file",
    )

    fake_gcs = MagicMock()
    client_factory = MagicMock(return_value=fake_gcs)
    with patch("datasetinsights.io.gcs.Client", client_factory):
        client = GCSClient()

        # Wire the lookup chain: gcs client -> bucket -> blob.
        fake_bucket = MagicMock()
        fake_blob = MagicMock()
        fake_gcs.get_bucket.return_value = fake_bucket
        fake_bucket.blob.return_value = fake_blob

        # Download path.
        client.download(bucket_name, object_key, localfile)
        fake_gcs.get_bucket.assert_called_with(bucket_name)
        fake_bucket.blob.assert_called_with(object_key)
        fake_blob.download_to_filename.assert_called_with(localfile)

        # Upload path.
        client.upload(localfile, bucket_name, object_key)
        fake_blob.upload_from_filename.assert_called_with(localfile)
示例#2
0
def test_gcs_client_upload_file_bucket_key(mock_isdir, mock_upload_file):
    """Uploading a single file by bucket/key is routed to the file uploader.

    Args:
        mock_isdir: mock whose return value is forced to ``False`` so the
            local path is treated as a file (presumably injected by a patch
            decorator that is not visible here -- confirm).
        mock_upload_file: mock expected to receive the upload call
            (presumably injected the same way -- confirm).
    """
    mock_isdir.return_value = False
    file_path = local_path + file_name
    fake_gcs = MagicMock()
    client_factory = MagicMock(return_value=fake_gcs)

    with patch("datasetinsights.io.gcs.Client", client_factory):
        gcs = GCSClient()
        gcs.upload(local_path=file_path, bucket=bucket_name, key=base_key)

        # The bucket handed to the uploader is whatever get_bucket returns.
        expected_kwargs = {
            "bucket": fake_gcs.get_bucket.return_value,
            "key": base_key,
            "local_path": file_path,
        }
        mock_upload_file.assert_called_with(**expected_kwargs)
示例#3
0
def test_gcs_client_upload_folder_url(mock_isdir, mock_upload_folder):
    """Uploading a directory by URL is routed to the folder uploader.

    Args:
        mock_isdir: mock whose return value is forced to ``True`` so the
            local path is treated as a directory (presumably injected by a
            patch decorator that is not visible here -- confirm).
        mock_upload_folder: mock expected to receive the upload call
            (presumably injected the same way -- confirm).
    """
    mock_isdir.return_value = True
    fake_gcs = MagicMock()
    client_factory = MagicMock(return_value=fake_gcs)

    with patch("datasetinsights.io.gcs.Client", client_factory):
        gcs = GCSClient()
        gcs.upload(local_path=local_path, url=base_url, pattern="*")

        # The uploader is expected to receive base_key for a base_url upload,
        # together with the bucket that get_bucket returns.
        expected_kwargs = {
            "bucket": fake_gcs.get_bucket.return_value,
            "key": base_key,
            "local_path": local_path,
            "pattern": "*",
        }
        mock_upload_folder.assert_called_with(**expected_kwargs)
示例#4
0
class GCSEstimatorWriter:
    """Writes (saves) estimator checkpoints on GCS.

    Checkpoints are first written into a local temporary directory through a
    ``LocalEstimatorWriter`` and are then uploaded with ``GCSClient``.

    Args:
        cloud_path (str): GCS cloud path (e.g. gs://bucket/path/to/directory)
        prefix (str): filename prefix of the checkpoint files
        suffix (str): filename suffix of the checkpoint files

    """
    def __init__(self, cloud_path, prefix, *, suffix=DEFAULT_SUFFIX):
        # Keep the TemporaryDirectory object alive for as long as this writer
        # exists. Taking only ``.name`` from a throwaway object lets its
        # finalizer delete the directory immediately, after which the path is
        # recreated by hand and never cleaned up again.
        self._tempdir_handle = tempfile.TemporaryDirectory()
        # The local writer is asked to create its own directory, so give it a
        # not-yet-existing child of the managed temp dir; this works whether
        # or not the writer tolerates a directory that already exists.
        self._tempdir = os.path.join(self._tempdir_handle.name, "checkpoints")
        self._client = GCSClient()
        self._bucket, self._gcs_path = gcs_bucket_and_path(cloud_path)
        self._writer = LocalEstimatorWriter(self._tempdir,
                                            prefix,
                                            create_dir=True,
                                            suffix=suffix)

    def save(self, estimator, epoch=None):
        """Save estimator to checkpoint files on GCS.

        The checkpoint is written locally first, then uploaded under the
        configured GCS path using the same file name.

        Args:
            estimator (datasetinsights.estimators.Estimator):
                datasetinsights estimator object.
            epoch (int): the current epoch number. Default: None

        Returns:
            Full GCS cloud path to the saved checkpoint file.

        """
        # Function-scope import keeps this block self-contained.
        import posixpath

        path = self._writer.save(estimator, epoch)
        filename = os.path.basename(path)
        # GCS object keys always use "/" as separator, so build the key with
        # posixpath rather than the OS-dependent os.path.
        object_key = posixpath.join(self._gcs_path, filename)

        full_cloud_path = f"gs://{self._bucket}/{object_key}"

        logger.debug(f"Copying estimator from {path} to {full_cloud_path}")
        self._client.upload(path, self._bucket, object_key)

        return full_cloud_path