def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
    """Create a Learner from a model bundle.

    The bundle at ``model_bundle_uri`` is obtained via ``download_if_needed``
    and extracted to ``<tmp_dir>/model-bundle``. The bundle's
    ``pipeline-config.json`` is upgraded and built into a config, and the
    learner is then built from that config using the bundle's ``model.pth``.
    If the config references an external model definition and/or an external
    loss definition, their hubconf directories (resolved relative to the
    bundle's modules directory) are passed along to the build call.

    Args:
        model_bundle_uri (str): URI of the model bundle zip file.
        tmp_dir (str): Working directory used for the download and
            extraction; also forwarded to the learner build call.

    Returns:
        The object returned by ``cfg.learner.build(...)``.
    """
    bundle_zip = download_if_needed(model_bundle_uri, tmp_dir)
    bundle_dir = join(tmp_dir, 'model-bundle')
    unzip(bundle_zip, bundle_dir)

    raw_config = file_to_json(join(bundle_dir, 'pipeline-config.json'))
    cfg = build_config(upgrade_config(raw_config))

    modules_dir = join(bundle_dir, MODULES_DIRNAME)

    # use the model definition shipped with the bundle, if there is one
    model_def_path = None
    model_ext_cfg = cfg.learner.model.external_def
    if model_ext_cfg is not None:
        model_def_path = get_hubconf_dir_from_cfg(
            model_ext_cfg, parent=modules_dir)
        log.info(
            f'Using model definition found in bundle: {model_def_path}')

    # use the loss function definition shipped with the bundle, if there is
    # one
    loss_def_path = None
    loss_ext_cfg = cfg.learner.solver.external_loss_def
    if loss_ext_cfg is not None:
        loss_def_path = get_hubconf_dir_from_cfg(
            loss_ext_cfg, parent=modules_dir)
        log.info(f'Using loss definition found in bundle: {loss_def_path}')

    return cfg.learner.build(
        tmp_dir=tmp_dir,
        model_path=join(bundle_dir, 'model.pth'),
        model_def_path=model_def_path,
        loss_def_path=loss_def_path)
def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str,
                       tmp_dir: str, *args, **kwargs) -> Any:
    """Load an entrypoint from:
        - a local uri of a zip file, or
        - a local uri of a directory, or
        - a remote uri of zip file.

    The zip file should either have hubconf.py at the top level or contain
    a single sub-directory that contains hubconf.py at its top level. In the
    latter case, the sub-directory will be copied to hubconf_dir.

    Any uri that does not end in ".zip" (case-insensitive) is assumed to be
    a local directory and is copied to hubconf_dir, unless it already is
    hubconf_dir.

    Args:
        uri (str): A URI.
        hubconf_dir (str): The target directory where the contents from the
            uri will finally be saved to. Existing contents are replaced.
        entrypoint (str): Name of a callable present in hubconf.py.
        tmp_dir (str): Directory where the zip file will be downloaded to and
            initially extracted.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """
    uri_path = Path(uri)
    is_zip = uri_path.suffix.lower() == '.zip'
    if is_zip:
        # unzip into a fresh directory named after the zip file
        zip_path = download_if_needed(uri, tmp_dir)
        unzip_dir = join(tmp_dir, uri_path.stem)
        _remove_dir(unzip_dir)
        unzip(zip_path, target_dir=unzip_dir)
        unzipped_contents = list(glob(f'{unzip_dir}/*', recursive=False))

        # the move below needs the destination to be free
        _remove_dir(hubconf_dir)

        # if the top level only contains a directory
        if (len(unzipped_contents) == 1) and isdir(unzipped_contents[0]):
            sub_dir = unzipped_contents[0]
            shutil.move(sub_dir, hubconf_dir)
        else:
            shutil.move(unzip_dir, hubconf_dir)
        _remove_dir(unzip_dir)
    # assume uri is local and attempt copying
    else:
        # Only copy if needed. samefile() (assuming os.path.samefile) stats
        # both paths and raises if either is missing, so a hubconf_dir that
        # does not exist yet must be treated as "copy needed" rather than
        # being passed to samefile().
        already_in_place = (Path(hubconf_dir).exists()
                            and samefile(uri, hubconf_dir))
        if not already_in_place:
            _remove_dir(hubconf_dir)
            shutil.copytree(uri, hubconf_dir)

    out = torch_hub_load_local(hubconf_dir, entrypoint, *args, **kwargs)
    return out
def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
    """Create a Learner from a model bundle.

    The bundle at ``model_bundle_uri`` is obtained via ``download_if_needed``
    and extracted to ``<tmp_dir>/model-bundle``. The bundle's
    ``pipeline-config.json`` is upgraded and built into a config, and the
    learner is then built from that config using the bundle's ``model.pth``.

    Args:
        model_bundle_uri (str): URI of the model bundle zip file.
        tmp_dir (str): Working directory used for the download and
            extraction; also forwarded to the learner build call.

    Returns:
        The object returned by ``cfg.learner.build(...)``.
    """
    bundle_zip = download_if_needed(model_bundle_uri, tmp_dir)
    bundle_dir = join(tmp_dir, 'model-bundle')
    unzip(bundle_zip, bundle_dir)

    raw_config = file_to_json(join(bundle_dir, 'pipeline-config.json'))
    cfg = build_config(upgrade_config(raw_config))

    weights_path = join(bundle_dir, 'model.pth')
    return cfg.learner.build(tmp_dir, model_path=weights_path)
def read_stac(uri: str, unzip_dir: Optional[str] = None) -> List[dict]:
    """Parse the contents of a STAC catalog (downloading it first, if remote).
    If the uri is a zip file, unzip it, find catalog.json inside it and parse
    that.

    Args:
        uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip
            file containing a STAC catalog JSON file.
        unzip_dir (Optional[str]): Directory to extract the zip file into.
            Required if uri is a zip file; not used otherwise. Defaults to
            None.

    Raises:
        ValueError: If uri is a zip file but unzip_dir is not provided.
        FileNotFoundError: If catalog.json is not found inside the zip file.
        Exception: If multiple catalog.json's are found inside the zip file.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and "aoi_geometry".
            Each dict corresponds to one label item and its associated image
            assets in the STAC catalog.
    """
    is_zip = Path(uri).suffix.lower() == '.zip'

    # Validate up front so that a misconfigured call fails before any
    # (possibly large, possibly remote) download is attempted.
    if is_zip and unzip_dir is None:
        raise ValueError(
            f'uri ("{uri}") is a zip file, but no unzip_dir provided.')

    with TemporaryDirectory() as tmp_dir:
        downloaded_path = download_if_needed(uri, tmp_dir)
        if not is_zip:
            return parse_stac(downloaded_path)

        unzip(downloaded_path, target_dir=unzip_dir)

        # the catalog may sit at any depth inside the extracted tree
        catalog_paths = list(Path(unzip_dir).glob('**/catalog.json'))
        if not catalog_paths:
            raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.')
        if len(catalog_paths) > 1:
            raise Exception(f'More than one "catalog.json" found in '
                            f'{uri}.')
        return parse_stac(str(catalog_paths[0]))