def test_copy_from_http(self):
    """Download over HTTPS and verify the file lands at the expected cache path."""
    url = ('https://raw.githubusercontent.com/tensorflow/models/'
           '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/README.md')
    # download_if_needed mirrors the URL structure under an 'http' subdir.
    expected_path = os.path.join(
        self.tmp_dir.name, 'http', 'raw.githubusercontent.com',
        'tensorflow/models',
        '17fa52864bfc7a7444a8b921d8a8eb1669e14ebd/README.md')
    download_if_needed(url, self.tmp_dir.name)
    self.assertTrue(file_exists(expected_path))
    os.remove(expected_path)
def load_init_weights(self):
    """Load the weights to initialize model."""
    init_weights_uri = self.cfg.model.init_weights
    if not init_weights_uri:
        return
    # Fetch the checkpoint locally (no-op if already a local path).
    weights_path = download_if_needed(init_weights_uri, self.tmp_dir)
    state_dict = torch.load(weights_path, map_location=self.device)
    self.model.load_state_dict(state_dict)
def unzip_data(self, uri: Union[str, List[str]]) -> List[str]:
    """Unzip dataset zip files.

    Args:
        uri: a list of URIs of zip files or the URI of a directory containing
            zip files

    Returns:
        paths to directories that each contain contents of one zip file
    """
    # Normalize the input into a list of zip-file URIs.
    if isinstance(uri, list):
        zip_uris = uri
    elif uri.endswith('.zip'):
        zip_uris = [uri]
    else:
        zip_uris = list_paths(uri, 'zip')

    data_dirs = []
    for zip_ind, zip_uri in enumerate(zip_uris):
        zip_path = get_local_path(zip_uri, self.data_cache_dir)
        if not isfile(zip_path):
            zip_path = download_if_needed(zip_uri, self.data_cache_dir)
        # Each archive extracts into its own uniquely-named directory.
        data_dir = join(self.tmp_dir, 'data', str(uuid.uuid4()), str(zip_ind))
        with zipfile.ZipFile(zip_path, 'r') as zipf:
            zipf.extractall(data_dir)
        data_dirs.append(data_dir)
    return data_dirs
def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
    """Create a Learner from a model bundle.

    Downloads and extracts the bundle, rebuilds the pipeline config
    (upgrading it if written by an older version), and wires up any model
    or loss definitions shipped inside the bundle before building the
    Learner.
    """
    bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
    bundle_dir = join(tmp_dir, 'model-bundle')
    unzip(bundle_path, bundle_dir)

    model_path = join(bundle_dir, 'model.pth')
    config_path = join(bundle_dir, 'pipeline-config.json')
    cfg = build_config(upgrade_config(file_to_json(config_path)))

    hub_dir = join(bundle_dir, MODULES_DIRNAME)

    # retrieve existing model definition, if available
    model_def_path = None
    model_ext_cfg = cfg.learner.model.external_def
    if model_ext_cfg is not None:
        model_def_path = get_hubconf_dir_from_cfg(model_ext_cfg, parent=hub_dir)
        log.info(f'Using model definition found in bundle: {model_def_path}')

    # retrieve existing loss function definition, if available
    loss_def_path = None
    loss_ext_cfg = cfg.learner.solver.external_loss_def
    if loss_ext_cfg is not None:
        loss_def_path = get_hubconf_dir_from_cfg(loss_ext_cfg, parent=hub_dir)
        log.info(f'Using loss definition found in bundle: {loss_def_path}')

    return cfg.learner.build(
        tmp_dir=tmp_dir,
        model_path=model_path,
        model_def_path=model_def_path,
        loss_def_path=loss_def_path)
def test_download_if_needed_local(self):
    """Round-trip a local file through upload_or_copy/download_if_needed."""
    # Reading a file that does not exist yet must raise.
    with self.assertRaises(NotReadableError):
        file_to_str(self.local_path)
    str_to_file(self.content_str, self.local_path)
    # Copying a local path onto itself should be a safe no-op.
    upload_or_copy(self.local_path, self.local_path)
    # A local source should come back as the same path, not a copy.
    result_path = download_if_needed(self.local_path, self.tmp_dir.name)
    self.assertEqual(result_path, self.local_path)
def _download_data(self, tmp_dir):
    """Download any data needed for this Raster Source.

    Return a single local path representing the image or a VRT of the data.
    """
    # A single URI needs no VRT; just fetch (or pass through) that file.
    if len(self.uris) == 1:
        return download_if_needed(self.uris[0], tmp_dir)
    # Multiple URIs get combined into one virtual raster.
    return download_and_build_vrt(self.uris, tmp_dir)
def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str,
                       tmp_dir: str, *args, **kwargs) -> Any:
    """Load an entrypoint from:
    - a local uri of a zip file, or
    - a local uri of a directory, or
    - a remote uri of zip file.

    The zip file should either have hubconf.py at the top level or contain
    a single sub-directory that contains hubconf.py at its top level. In the
    latter case, the sub-directory will be copied to hubconf_dir.

    Args:
        uri (str): A URI.
        hubconf_dir (str): The target directory where the contents from the
            uri will finally be saved to.
        entrypoint (str): Name of a callable present in hubconf.py.
        tmp_dir (str): Directory where the zip file will be downloaded to and
            initially extracted.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    src_path = Path(uri)
    if src_path.suffix.lower() == '.zip':
        # Download (if remote) and extract into a fresh scratch directory.
        zip_path = download_if_needed(uri, tmp_dir)
        extract_dir = join(tmp_dir, src_path.stem)
        _remove_dir(extract_dir)
        unzip(zip_path, target_dir=extract_dir)
        extracted = list(glob(f'{extract_dir}/*', recursive=False))
        _remove_dir(hubconf_dir)
        # If the archive wraps everything in one sub-directory, promote
        # that sub-directory to hubconf_dir; otherwise move the whole
        # extraction directory.
        if len(extracted) == 1 and isdir(extracted[0]):
            shutil.move(extracted[0], hubconf_dir)
        else:
            shutil.move(extract_dir, hubconf_dir)
        _remove_dir(extract_dir)
    # assume uri is a local directory and attempt copying
    elif not samefile(uri, hubconf_dir):
        # only copy if source and target differ
        _remove_dir(hubconf_dir)
        shutil.copytree(uri, hubconf_dir)
    return torch_hub_load_local(hubconf_dir, entrypoint, *args, **kwargs)
def test_download_if_needed_s3(self):
    """Upload to S3, download back, and verify a bad bucket raises on write."""
    # Reading an object that does not exist yet must raise.
    with self.assertRaises(NotReadableError):
        file_to_str(self.s3_path)
    str_to_file(self.content_str, self.local_path)
    upload_or_copy(self.local_path, self.s3_path)
    # Downloading should recover exactly what was uploaded.
    downloaded_path = download_if_needed(self.s3_path, self.tmp_dir.name)
    self.assertEqual(self.content_str, file_to_str(downloaded_path))
    wrong_path = 's3://wrongpath/x.txt'
    with self.assertRaises(NotWritableError):
        upload_or_copy(downloaded_path, wrong_path)
def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
    """Create a Learner from a model bundle."""
    # Fetch and extract the bundle archive.
    bundle_path = download_if_needed(model_bundle_uri, tmp_dir)
    bundle_dir = join(tmp_dir, 'model-bundle')
    unzip(bundle_path, bundle_dir)
    model_path = join(bundle_dir, 'model.pth')
    # Rebuild the pipeline config, upgrading older serialized formats.
    config_dict = file_to_json(join(bundle_dir, 'pipeline-config.json'))
    cfg = build_config(upgrade_config(config_dict))
    return cfg.learner.build(tmp_dir, model_path=model_path)
def read_stac(uri: str, unzip_dir: Optional[str] = None) -> List[dict]:
    """Parse the contents of a STAC catalog (downloading it first, if remote).

    If the uri is a zip file, unzip it, find catalog.json inside it and parse
    that.

    Args:
        uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip
            file containing a STAC catalog JSON file.
        unzip_dir (Optional[str]): Directory to extract a zip catalog into.
            Required when uri points to a zip file.

    Raises:
        ValueError: If uri is a zip file and unzip_dir is not provided.
        FileNotFoundError: If catalog.json is not found inside the zip file.
        Exception: If multiple catalog.json's are found inside the zip file.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and
            "aoi_geometry". Each dict corresponds to one label item and its
            associated image assets in the STAC catalog.
    """
    is_zip = Path(uri).suffix.lower() == '.zip'
    with TemporaryDirectory() as tmp_dir:
        local_path = download_if_needed(uri, tmp_dir)
        if not is_zip:
            return parse_stac(local_path)
        if unzip_dir is None:
            raise ValueError(
                f'uri ("{uri}") is a zip file, but no unzip_dir provided.')
        unzip(local_path, target_dir=unzip_dir)
        # The catalog may be nested at any depth inside the archive.
        catalog_paths = list(Path(unzip_dir).glob('**/catalog.json'))
        num_found = len(catalog_paths)
        if num_found == 0:
            raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.')
        if num_found > 1:
            raise Exception(f'More than one "catalog.json" found in '
                            f'{uri}.')
        return parse_stac(str(catalog_paths[0]))
def download_and_build_vrt(image_uris, tmp_dir):
    """Download each image (if remote) and build a single VRT indexing them."""
    log.info('Building VRT...')
    local_paths = [download_if_needed(uri, tmp_dir) for uri in image_uris]
    vrt_path = os.path.join(tmp_dir, 'index.vrt')
    build_vrt(vrt_path, local_paths)
    return vrt_path
def _zxy2geotiff(tile_schema, zoom, bounds, output_uri, make_cog=False):
    """Generates a GeoTIFF of a bounded region from a ZXY tile server.

    Args:
        tile_schema: (str) the URI schema for zxy tiles (ie. a slippy map tile
            server) of the form /tileserver-uri/{z}/{x}/{y}.png. If {-y} is
            used, the tiles are assumed to be indexed using TMS coordinates,
            where the y axis starts at the southernmost point. The URI can be
            for http, S3, or the local file system.
        zoom: (int) the zoom level to use when retrieving tiles
        bounds: (list) a list of length 4 containing min_lat, min_lng,
            max_lat, max_lng
        output_uri: (str) where to save the GeoTIFF. The URI can be for
            http, S3, or the local file system
        make_cog: (bool) if True, pass the output through create_cog before
            uploading; otherwise upload/copy the plain GeoTIFF

    Raises:
        ValueError: if min_lat >= max_lat or min_lng >= max_lng
    """
    min_lat, min_lng, max_lat, max_lng = bounds
    if min_lat >= max_lat:
        raise ValueError('min_lat must be < max_lat')
    if min_lng >= max_lng:
        raise ValueError('min_lng must be < max_lng')

    # A '{-y}' placeholder marks TMS-indexed tiles; normalize the schema to
    # '{y}' and remember to flip y per tile below.
    is_tms = False
    if '{-y}' in tile_schema:
        tile_schema = tile_schema.replace('{-y}', '{y}')
        is_tms = True

    # NOTE(review): the TemporaryDirectory object is kept alive via
    # tmp_dir_obj so the directory survives until this function returns.
    tmp_dir_obj = tempfile.TemporaryDirectory()
    tmp_dir = tmp_dir_obj.name

    # Get range of tiles that cover bounds.
    output_path = get_local_path(output_uri, tmp_dir)
    tile_sz = 256
    t = mercantile.tile(min_lng, max_lat, zoom)
    xmin, ymin = t.x, t.y
    t = mercantile.tile(max_lng, min_lat, zoom)
    xmax, ymax = t.x, t.y

    # The supplied bounds are contained within the "tile bounds" -- ie. the
    # bounds of the set of tiles that covers the supplied bounds. Therefore,
    # we need to crop out the imagery that lies within the supplied bounds.
    # We do this by computing a top, bottom, left, and right offset in pixel
    # units of the supplied bounds against the tile bounds. Getting the offsets
    # in pixel units involves converting lng/lat to web mercator units since we
    # assume that is the CRS of the tiles. These offsets are then used to crop
    # individual tiles and place them correctly into the output raster.
    nw_merc_x, nw_merc_y = lnglat2merc(min_lng, max_lat)
    left_pix_offset, top_pix_offset = merc2pixel(xmin, ymin, zoom, nw_merc_x,
                                                 nw_merc_y)
    se_merc_x, se_merc_y = lnglat2merc(max_lng, min_lat)
    se_left_pix_offset, se_top_pix_offset = merc2pixel(xmax, ymax, zoom,
                                                       se_merc_x, se_merc_y)
    # Offsets on the SE side are measured from the far edge of the SE tile.
    right_pix_offset = tile_sz - se_left_pix_offset
    bottom_pix_offset = tile_sz - se_top_pix_offset

    # Output size = full tile mosaic minus the four crop offsets.
    uncropped_height = tile_sz * (ymax - ymin + 1)
    uncropped_width = tile_sz * (xmax - xmin + 1)
    height = uncropped_height - top_pix_offset - bottom_pix_offset
    width = uncropped_width - left_pix_offset - right_pix_offset

    transform = rasterio.transform.from_bounds(nw_merc_x, se_merc_y,
                                               se_merc_x, nw_merc_y, width,
                                               height)
    with rasterio.open(
            output_path,
            'w',
            driver='GTiff',
            height=height,
            width=width,
            count=3,
            crs='epsg:3857',
            transform=transform,
            dtype=rasterio.uint8) as dataset:
        out_x = 0
        for xi, x in enumerate(range(xmin, xmax + 1)):
            # Edge tiles get cropped by the pixel offsets computed above;
            # interior tiles are used whole.
            tile_xmin, tile_xmax = 0, tile_sz - 1
            if x == xmin:
                tile_xmin += left_pix_offset
            if x == xmax:
                tile_xmax -= right_pix_offset
            window_width = tile_xmax - tile_xmin + 1
            out_y = 0
            for yi, y in enumerate(range(ymin, ymax + 1)):
                tile_ymin, tile_ymax = 0, tile_sz - 1
                if y == ymin:
                    tile_ymin += top_pix_offset
                if y == ymax:
                    tile_ymax -= bottom_pix_offset
                window_height = tile_ymax - tile_ymin + 1

                # Convert from xyz to tms if needed.
                # https://gist.github.com/tmcw/4954720
                if is_tms:
                    y = (2**zoom) - y - 1

                tile_uri = tile_schema.format(x=x, y=y, z=zoom)
                tile_path = download_if_needed(tile_uri, tmp_dir)
                # NOTE(review): assumes each tile decodes to an HxWxC array
                # with at least 3 channels -- confirm for the tile server used.
                img = np.array(Image.open(tile_path))
                img = img[tile_ymin:tile_ymax + 1, tile_xmin:tile_xmax + 1, :]
                window = Window(out_x, out_y, window_width, window_height)
                # rasterio expects band-first layout; keep only RGB.
                dataset.write(
                    np.transpose(img[:, :, 0:3], (2, 0, 1)), window=window)
                out_y += window_height
            out_x += window_width

    if make_cog:
        create_cog(output_path, output_uri, tmp_dir)
    else:
        upload_or_copy(output_path, output_uri)