Example #1
    def test_copy_from_http(self):
        http_path = ('https://raw.githubusercontent.com/tensorflow/models/'
                     '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/README.md')
        expected = os.path.join(
            self.tmp_dir.name, 'http', 'raw.githubusercontent.com',
            'tensorflow/models',
            '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/README.md')
        download_if_needed(http_path, self.tmp_dir.name)

        self.assertTrue(file_exists(expected))
        os.remove(expected)
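The test asserts a particular cache layout. A minimal sketch of that mapping, assuming the layout implied by the expected path above (this helper is hypothetical; download_if_needed computes the real path internally):

import os
from urllib.parse import urlparse

def expected_cache_path(uri: str, download_dir: str) -> str:
    # The test shows an https URI cached under an 'http' subdirectory:
    # <download_dir>/http/<host>/<path>
    parsed = urlparse(uri)
    return os.path.join(download_dir, 'http', parsed.netloc,
                        parsed.path.lstrip('/'))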
Example #2
    def load_init_weights(self):
        """Load the weights to initialize the model."""
        if self.cfg.model.init_weights:
            weights_path = download_if_needed(self.cfg.model.init_weights,
                                              self.tmp_dir)
            self.model.load_state_dict(
                torch.load(weights_path, map_location=self.device))
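The map_location argument is what makes this load device-agnostic. A standalone sketch of the same pattern (the model and weights path are placeholders):

import torch
import torch.nn as nn

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = nn.Linear(4, 2)  # placeholder model
# map_location remaps tensors saved on another device (e.g. CUDA -> CPU).
state_dict = torch.load('weights.pth', map_location=device)
model.load_state_dict(state_dict)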
Example #3
    def unzip_data(self, uri: Union[str, List[str]]) -> List[str]:
        """Unzip dataset zip files.

        Args:
            uri: a list of URIs of zip files or the URI of a directory containing
                zip files

        Returns:
            paths to directories that each contain the contents of one zip file
        """
        data_dirs = []

        if isinstance(uri, list):
            zip_uris = uri
        else:
            zip_uris = ([uri] if uri.endswith('.zip') else list_paths(
                uri, 'zip'))

        for zip_ind, zip_uri in enumerate(zip_uris):
            zip_path = get_local_path(zip_uri, self.data_cache_dir)
            if not isfile(zip_path):
                zip_path = download_if_needed(zip_uri, self.data_cache_dir)
            with zipfile.ZipFile(zip_path, 'r') as zipf:
                data_dir = join(self.tmp_dir, 'data', str(uuid.uuid4()),
                                str(zip_ind))
                data_dirs.append(data_dir)
                zipf.extractall(data_dir)

        return data_dirs
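A hypothetical call, assuming a pipeline object with data_cache_dir and tmp_dir configured (bucket and file names are placeholders):

data_dirs = pipeline.unzip_data([
    's3://my-bucket/train.zip',
    's3://my-bucket/val.zip',
])
# Each returned directory holds the extracted contents of one zip file,
# under <tmp_dir>/data/<uuid>/<zip index>.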
Example #4
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle."""
        model_bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        model_bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(model_bundle_path, model_bundle_dir)

        config_path = join(model_bundle_dir, 'pipeline-config.json')
        model_path = join(model_bundle_dir, 'model.pth')

        config_dict = file_to_json(config_path)
        config_dict = upgrade_config(config_dict)

        cfg = build_config(config_dict)

        hub_dir = join(model_bundle_dir, MODULES_DIRNAME)
        model_def_path = None
        loss_def_path = None

        # retrieve existing model definition, if available
        ext_cfg = cfg.learner.model.external_def
        if ext_cfg is not None:
            model_def_path = get_hubconf_dir_from_cfg(ext_cfg, parent=hub_dir)
            log.info(
                f'Using model definition found in bundle: {model_def_path}')

        # retrieve existing loss function definition, if available
        ext_cfg = cfg.learner.solver.external_loss_def
        if ext_cfg is not None:
            loss_def_path = get_hubconf_dir_from_cfg(ext_cfg, parent=hub_dir)
            log.info(f'Using loss definition found in bundle: {loss_def_path}')

        return cfg.learner.build(tmp_dir=tmp_dir,
                                 model_path=model_path,
                                 model_def_path=model_def_path,
                                 loss_def_path=loss_def_path)
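For reference, a hypothetical invocation (the bundle URI is a placeholder, and this assumes the method is exposed as a static method on a Learner class):

learner = Learner.from_model_bundle(
    model_bundle_uri='s3://my-bucket/model-bundle.zip',
    tmp_dir='/tmp/bundle')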
Example #5
    def test_download_if_needed_local(self):
        with self.assertRaises(NotReadableError):
            file_to_str(self.local_path)

        str_to_file(self.content_str, self.local_path)
        upload_or_copy(self.local_path, self.local_path)
        local_path = download_if_needed(self.local_path, self.tmp_dir.name)
        self.assertEqual(local_path, self.local_path)
Example #6
    def _download_data(self, tmp_dir):
        """Download any data needed for this Raster Source.

        Return a single local path representing the image or a VRT of the data.
        """
        if len(self.uris) == 1:
            return download_if_needed(self.uris[0], tmp_dir)
        else:
            return download_and_build_vrt(self.uris, tmp_dir)
Example #7
def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str,
                       tmp_dir: str, *args, **kwargs) -> Any:
    """Load an entrypoint from:
        - a local uri of a zip file, or
        - a local uri of a directory, or
        - a remote uri of a zip file.

    The zip file should either have hubconf.py at the top level or contain
    a single sub-directory that contains hubconf.py at its top level. In the
    latter case, the sub-directory will be copied to hubconf_dir.

    Args:
        uri (str): A URI.
        hubconf_dir (str): The target directory where the contents from the uri
            will finally be saved to.
        entrypoint (str): Name of a callable present in hubconf.py.
        tmp_dir (str): Directory where the zip file will be downloaded to and
            initially extracted.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """

    uri_path = Path(uri)
    is_zip = uri_path.suffix.lower() == '.zip'
    if is_zip:
        # unzip
        zip_path = download_if_needed(uri, tmp_dir)
        unzip_dir = join(tmp_dir, uri_path.stem)
        _remove_dir(unzip_dir)
        unzip(zip_path, target_dir=unzip_dir)
        unzipped_contents = list(glob(f'{unzip_dir}/*', recursive=False))

        _remove_dir(hubconf_dir)

        # if the top level only contains a directory
        if (len(unzipped_contents) == 1) and isdir(unzipped_contents[0]):
            sub_dir = unzipped_contents[0]
            shutil.move(sub_dir, hubconf_dir)
        else:
            shutil.move(unzip_dir, hubconf_dir)

        _remove_dir(unzip_dir)
    # assume uri is local and attempt copying
    else:
        # only copy if needed
        if not samefile(uri, hubconf_dir):
            _remove_dir(hubconf_dir)
            shutil.copytree(uri, hubconf_dir)

    out = torch_hub_load_local(hubconf_dir, entrypoint, *args, **kwargs)
    return out
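A hypothetical call, assuming the zip contains hubconf.py with a resnet18 entrypoint (the URI, paths, and entrypoint name are placeholders):

model = torch_hub_load_uri(
    uri='s3://my-bucket/resnet-repo.zip',  # zip with hubconf.py at its top level
    hubconf_dir='/opt/hub/resnet',
    entrypoint='resnet18',
    tmp_dir='/tmp/hub',
    pretrained=False)  # forwarded to the entrypoint as a keyword arg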
Example #8
    def test_download_if_needed_s3(self):
        with self.assertRaises(NotReadableError):
            file_to_str(self.s3_path)

        str_to_file(self.content_str, self.local_path)
        upload_or_copy(self.local_path, self.s3_path)
        local_path = download_if_needed(self.s3_path, self.tmp_dir.name)
        content_str = file_to_str(local_path)
        self.assertEqual(self.content_str, content_str)

        wrong_path = 's3://wrongpath/x.txt'
        with self.assertRaises(NotWritableError):
            upload_or_copy(local_path, wrong_path)
Example #9
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Create a Learner from a model bundle."""
        model_bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
        model_bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(model_bundle_path, model_bundle_dir)

        config_path = join(model_bundle_dir, 'pipeline-config.json')
        model_path = join(model_bundle_dir, 'model.pth')

        config_dict = file_to_json(config_path)
        config_dict = upgrade_config(config_dict)

        cfg = build_config(config_dict)
        return cfg.learner.build(tmp_dir, model_path=model_path)
Example #10
def read_stac(uri: str, unzip_dir: Optional[str] = None) -> List[dict]:
    """Parse the contents of a STAC catalog (downloading it first, if
    remote). If the uri is a zip file, unzip it, find catalog.json inside it
    and parse that.

    Args:
        uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip
            file containing a STAC catalog JSON file.
        unzip_dir (Optional[str]): Directory to extract the zip file into.
            Required if uri points to a zip file.

    Raises:
        FileNotFoundError: If catalog.json is not found inside the zip file.
        Exception: If multiple catalog.json's are found inside the zip file.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and "aoi_geometry".
            Each dict corresponds to one label item and its associated image
            assets in the STAC catalog.
    """
    uri_path = Path(uri)
    is_zip = uri_path.suffix.lower() == '.zip'

    with TemporaryDirectory() as tmp_dir:
        catalog_path = download_if_needed(uri, tmp_dir)
        if not is_zip:
            return parse_stac(catalog_path)
        if unzip_dir is None:
            raise ValueError(
                f'uri ("{uri}") is a zip file, but no unzip_dir provided.')
        zip_path = catalog_path
        unzip(zip_path, target_dir=unzip_dir)
        catalog_paths = list(Path(unzip_dir).glob('**/catalog.json'))
        if len(catalog_paths) == 0:
            raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.')
        elif len(catalog_paths) > 1:
            raise Exception(f'More than one "catalog.json" found in '
                            f'{uri}.')
        catalog_path = str(catalog_paths[0])
        return parse_stac(catalog_path)
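A hypothetical usage (the catalog URI and unzip directory are placeholders):

items = read_stac('s3://my-bucket/stac-catalog.zip', unzip_dir='/tmp/stac')
for item in items:
    # Each dict pairs one label item with its associated image assets.
    print(item['label_uri'], item['image_uris'])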
Example #11
def download_and_build_vrt(image_uris, tmp_dir):
    log.info('Building VRT...')
    image_paths = [download_if_needed(uri, tmp_dir) for uri in image_uris]
    image_path = os.path.join(tmp_dir, 'index.vrt')
    build_vrt(image_path, image_paths)
    return image_path
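A hypothetical call (the scene URIs are placeholders):

vrt_path = download_and_build_vrt(
    ['s3://my-bucket/scene-1.tif', 's3://my-bucket/scene-2.tif'],
    '/tmp/vrt')
# vrt_path points at index.vrt, which mosaics the downloaded images.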
Example #12
def _zxy2geotiff(tile_schema, zoom, bounds, output_uri, make_cog=False):
    """Generates a GeoTIFF of a bounded region from a ZXY tile server.

    Args:
        tile_schema: (str) the URI schema for zxy tiles (i.e. a slippy map tile server)
            of the form /tileserver-uri/{z}/{x}/{y}.png. If {-y} is used, the tiles
            are assumed to be indexed using TMS coordinates, where the y axis starts
            at the southernmost point. The URI can be for http, S3, or the local
            file system.
        zoom: (int) the zoom level to use when retrieving tiles
        bounds: (list) a list of length 4 containing min_lat, min_lng,
            max_lat, max_lng
        output_uri: (str) where to save the GeoTIFF. The URI can be for http, S3, or the
            local file system
        make_cog: (bool) if True, save the output as a Cloud-Optimized GeoTIFF
            using create_cog; otherwise, upload or copy it to output_uri directly
    """
    min_lat, min_lng, max_lat, max_lng = bounds
    if min_lat >= max_lat:
        raise ValueError('min_lat must be < max_lat')
    if min_lng >= max_lng:
        raise ValueError('min_lng must be < max_lng')

    is_tms = False
    if '{-y}' in tile_schema:
        tile_schema = tile_schema.replace('{-y}', '{y}')
        is_tms = True

    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    # Get range of tiles that cover bounds.
    output_path = get_local_path(output_uri, tmp_dir)
    tile_sz = 256
    t = mercantile.tile(min_lng, max_lat, zoom)
    xmin, ymin = t.x, t.y
    t = mercantile.tile(max_lng, min_lat, zoom)
    xmax, ymax = t.x, t.y

    # The supplied bounds are contained within the "tile bounds" -- i.e. the
    # bounds of the set of tiles that covers the supplied bounds. Therefore,
    # we need to crop out the imagery that lies within the supplied bounds.
    # We do this by computing a top, bottom, left, and right offset in pixel
    # units of the supplied bounds against the tile bounds. Getting the offsets
    # in pixel units involves converting lng/lat to web mercator units since we
    # assume that is the CRS of the tiles. These offsets are then used to crop
    # individual tiles and place them correctly into the output raster.
    nw_merc_x, nw_merc_y = lnglat2merc(min_lng, max_lat)
    left_pix_offset, top_pix_offset = merc2pixel(xmin, ymin, zoom, nw_merc_x,
                                                 nw_merc_y)

    se_merc_x, se_merc_y = lnglat2merc(max_lng, min_lat)
    se_left_pix_offset, se_top_pix_offset = merc2pixel(xmax, ymax, zoom,
                                                       se_merc_x, se_merc_y)
    right_pix_offset = tile_sz - se_left_pix_offset
    bottom_pix_offset = tile_sz - se_top_pix_offset

    uncropped_height = tile_sz * (ymax - ymin + 1)
    uncropped_width = tile_sz * (xmax - xmin + 1)
    height = uncropped_height - top_pix_offset - bottom_pix_offset
    width = uncropped_width - left_pix_offset - right_pix_offset

    transform = rasterio.transform.from_bounds(nw_merc_x, se_merc_y, se_merc_x,
                                               nw_merc_y, width, height)
    with rasterio.open(output_path,
                       'w',
                       driver='GTiff',
                       height=height,
                       width=width,
                       count=3,
                       crs='epsg:3857',
                       transform=transform,
                       dtype=rasterio.uint8) as dataset:
        out_x = 0
        for xi, x in enumerate(range(xmin, xmax + 1)):
            tile_xmin, tile_xmax = 0, tile_sz - 1
            if x == xmin:
                tile_xmin += left_pix_offset
            if x == xmax:
                tile_xmax -= right_pix_offset
            window_width = tile_xmax - tile_xmin + 1

            out_y = 0
            for yi, y in enumerate(range(ymin, ymax + 1)):
                tile_ymin, tile_ymax = 0, tile_sz - 1
                if y == ymin:
                    tile_ymin += top_pix_offset
                if y == ymax:
                    tile_ymax -= bottom_pix_offset
                window_height = tile_ymax - tile_ymin + 1

                # Convert from xyz to tms if needed.
                # https://gist.github.com/tmcw/4954720
                if is_tms:
                    y = (2**zoom) - y - 1
                tile_uri = tile_schema.format(x=x, y=y, z=zoom)
                tile_path = download_if_needed(tile_uri, tmp_dir)
                img = np.array(Image.open(tile_path))
                img = img[tile_ymin:tile_ymax + 1, tile_xmin:tile_xmax + 1, :]

                window = Window(out_x, out_y, window_width, window_height)
                dataset.write(np.transpose(img[:, :, 0:3], (2, 0, 1)),
                              window=window)
                out_y += window_height
            out_x += window_width

    if make_cog:
        create_cog(output_path, output_uri, tmp_dir)
    else:
        upload_or_copy(output_path, output_uri)
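A hypothetical invocation (the tile server URL and output URI are placeholders):

_zxy2geotiff(
    tile_schema='https://tiles.example.com/{z}/{x}/{y}.png',
    zoom=17,
    bounds=[37.79, -122.42, 37.80, -122.41],  # min_lat, min_lng, max_lat, max_lng
    output_uri='s3://my-bucket/out.tif',
    make_cog=True)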