예제 #1
0
def test_gcs_client_warpper():
    """GCSClient.download/upload must reach bucket/blob with the given args."""
    bucket_name = "fake_bucket"
    object_key = "path/to/object"
    localfile = "path/to/local/file"

    # Build the fake storage hierarchy up front: client -> bucket -> blob.
    fake_storage = MagicMock()
    fake_bucket = MagicMock()
    fake_blob = MagicMock()
    fake_storage.get_bucket = MagicMock(return_value=fake_bucket)
    fake_bucket.blob = MagicMock(return_value=fake_blob)
    fake_blob.download_to_filename = MagicMock()
    fake_blob.upload_from_filename = MagicMock()

    with patch(
            "datasetinsights.io.gcs.Client",
            MagicMock(return_value=fake_storage),
    ):
        client = GCSClient()

        client.download(bucket_name, object_key, localfile)
        fake_storage.get_bucket.assert_called_with(bucket_name)
        fake_bucket.blob.assert_called_with(object_key)
        fake_blob.download_to_filename.assert_called_with(localfile)

        client.upload(localfile, bucket_name, object_key)
        fake_blob.upload_from_filename.assert_called_with(localfile)
예제 #2
0
 def download(self, cloud_path=COCO_GCS_PATH):
     """Download and extract the COCO annotation and image archives.

     Archives already present under ``self.root`` are not downloaded again.

     Args:
         cloud_path (str): GCS folder that holds the dataset zip archives.
     """
     path = Path(self.root)
     path.mkdir(parents=True, exist_ok=True)
     client = GCSClient()

     def fetch_and_extract(object_key, local_zip):
         # Download one archive from GCS and unpack it into self.root.
         client.download(
             bucket_name=const.GCS_BUCKET,
             object_key=object_key,
             localfile=local_zip,
         )
         with zipfile.ZipFile(local_zip, "r") as zip_dir:
             zip_dir.extractall(self.root)

     logger.info("checking for local copy of data")
     annotations_zip_2017 = self._get_local_annotations_zip()
     if not os.path.exists(annotations_zip_2017):
         logger.info("no annotations zip file found, will download.")
         fetch_and_extract(
             f"{cloud_path}/annotations_trainval2017.zip",
             annotations_zip_2017,
         )

     images_local = self._get_local_images_zip()
     if not os.path.exists(images_local):
         logger.info(f"no zip file for images for {self.split} found,"
                     f" will download")
         fetch_and_extract(f"{cloud_path}/{self.split}2017.zip", images_local)
예제 #3
0
    def download(self, cloud_path):
        """Download cityscapes dataset

        Note:
            The current implementation assumes a GCS cloud path.
            Should we keep this method here if we want to support other cloud
            storage system?

        Args:
            cloud_path (str): cloud path of the dataset
        """
        path = Path(self.root)
        path.mkdir(parents=True, exist_ok=True)

        # One client serves every archive; constructing it per-iteration
        # (as before) was loop-invariant work.
        client = GCSClient()
        # NOTE: loop variable renamed from "zipfile" — that name shadowed the
        # stdlib zipfile module.
        for zip_name in ZIPFILES:
            localfile = os.path.join(self.root, zip_name)
            if os.path.isfile(localfile):
                # TODO: Check file hash to verify file integrity
                logger.debug(f"File {localfile} exists. Skip download.")
                continue
            # GCS object keys always use "/" separators; os.path.join would
            # emit "\\" on Windows, so build the key explicitly.
            object_key = f"{CITYSCAPES_GCS_PATH}/{zip_name}"

            logger.debug(
                f"Downloading file {localfile} from gs://{const.GCS_BUCKET}/"
                f"{object_key}")
            client.download(const.GCS_BUCKET, object_key, localfile)
예제 #4
0
def test_gcs_client_download_folder_url(mock_download_folder, mock_is_file):
    """A gs:// folder URL download must delegate to the folder helper."""
    mock_is_file.return_value = False
    fake_storage = MagicMock()
    patched_client = MagicMock(return_value=fake_storage)
    with patch("datasetinsights.io.gcs.Client", patched_client):
        GCSClient().download(local_path=local_path, url=base_url)
        mock_download_folder.assert_called_with(
            fake_storage.get_bucket(), base_key, local_path
        )
예제 #5
0
    def download(self, cloud_path):
        """Download nyu_v2 dataset
        The directory structure of the downloaded data is
        |--self.root
           |--nyudepth
               |--nyu_data.zip
               |--data
                   |--nyu2_test.csv
                   |--nyu2_test
                         |--00000_colors.png
                         |--00000_depth.png ...
                         |--01448_colors.png
                         |--01448_depth.png
                   |--nyu2_train.csv
                   |--nyu2_train
                         |--basement_0001a_out
                              |--1.jpg
                              |--1.png ...
                              |--281.jpg
                              |--281.png
                         ...
                         |--study_room_0005b_out
                              |--1.jpg
                              |--1.png ...
                              |--133.jpg
                              |--133.png
        Args:
            cloud_path (str): cloud path of the dataset
        """
        zip_file = os.path.join(self.root, ZIPFILE)
        unzip_dir = os.path.join(self.root, UNZIP_NAME)

        if os.path.isfile(zip_file):
            logger.debug(f"File {zip_file} exists. Skip download.")
        else:
            client = GCSClient()
            # GCS object keys always use "/" separators; os.path.join would
            # emit "\\" on Windows, so build the key explicitly.
            object_key = f"{NYU_GCS_PATH}/{ZIPFILE}"

            logger.debug(
                f"Downloading file {zip_file} from gs://{const.GCS_BUCKET}/"
                f"{object_key}")
            client.download(
                local_path=self.root,
                bucket=const.GCS_BUCKET,
                key=object_key,
            )

        if os.path.isdir(unzip_dir):
            logger.debug(f"File {unzip_dir} exists. Skip unzip.")
        else:
            # unzip the file
            with ZipFile(zip_file, "r") as zip_ref:
                zip_ref.extractall(self.root)
                logger.debug(f"Unzip file from {zip_file}")
예제 #6
0
def test_gcs_client_download_file_bucket_key(mock_download_file, mock_is_file):
    """A bucket/key file download must delegate to the file helper."""
    mock_is_file.return_value = True
    object_key = base_key + file_name
    fake_storage = MagicMock()
    patched_client = MagicMock(return_value=fake_storage)
    with patch("datasetinsights.io.gcs.Client", patched_client):
        GCSClient().download(
            local_path=local_path, bucket=bucket_name, key=object_key
        )
        mock_download_file.assert_called_with(
            fake_storage.get_bucket(), object_key, local_path
        )
예제 #7
0
def load_from_gcs(estimator, full_cloud_path):
    """Load estimator from checkpoint files on GCS.

    The checkpoint is fetched into a temporary directory that is removed
    once the estimator has loaded it.

    Args:
        estimator (datasetinsights.estimators.Estimator):
            datasetinsights estimator object.
        full_cloud_path: full path to the checkpoint file

    """
    checkpoint_name = os.path.basename(full_cloud_path)
    with tempfile.TemporaryDirectory() as temp_dir:
        local_checkpoint = os.path.join(temp_dir, checkpoint_name)
        logger.debug(
            f"Downloading estimator from {full_cloud_path} to {local_checkpoint}"
        )
        GCSClient().download(local_path=temp_dir, url=full_cloud_path)
        estimator.load(local_checkpoint)
예제 #8
0
    def download(self):
        """Download dataset from GCS

        Downloads the dataset folder once, then extracts the label archive
        and the tfexample archive for the configured split.
        """
        cloud_path = f"gs://{const.GCS_BUCKET}/{self.GCS_PATH}"
        client = GCSClient()
        # A single download of the folder URL fetches every archive; the
        # previous version issued this identical call a second time before
        # extracting the tfexamples, re-downloading the same data.
        client.download(url=cloud_path, local_path=self.root)

        # extract label file
        # NOTE(review): LABEL_ZIP is opened as-is — presumably a path valid
        # relative to the working directory; confirm it shouldn't be joined
        # with self.root.
        label_zip = self.LABEL_ZIP
        with zipfile.ZipFile(label_zip, "r") as zip_dir:
            zip_dir.extractall(self.root)

        # extract tfexamples for the configured dataset split
        tfexamples_zip = self.SPLITS_ZIP.get(self.split)
        with zipfile.ZipFile(tfexamples_zip, "r") as zip_dir:
            zip_dir.extractall(self.root)
예제 #9
0
class GCSDatasetDownloader(DatasetDownloader, protocol="gs://"):
    """Dataset downloader for data hosted on Google Cloud Storage."""

    def __init__(self, **kwargs):
        """Create the downloader with its own GCS client.

        Extra keyword arguments are accepted for interface compatibility
        and ignored.
        """
        self.client = GCSClient()

    def download(self, source_uri=None, output=None, **kwargs):
        """Fetch a dataset from GCS into a local directory.

        Args:
            source_uri: This is the downloader-uri that indicates where on
                GCS the dataset should be downloaded from.
                The expected source-uri follows these patterns
                gs://bucket/folder or gs://bucket/folder/data.zip

            output: This is the path to the directory
                where the download will store the dataset.
        """
        self.client.download(local_path=output, url=source_uri)