示例#1
0
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Build a Learner out of the contents of a model bundle.

        The bundle is obtained via ``download_if_needed`` and extracted to
        ``<tmp_dir>/model-bundle``. The ``pipeline-config.json`` inside it is
        upgraded and turned into a config object, whose learner config is
        then used to build the learner from the bundled ``model.pth``. If the
        config references an external model and/or loss definition, the
        matching directory under the bundle's modules directory is handed to
        the build call too; otherwise ``None`` is passed for it.

        Args:
            model_bundle_uri (str): URI of the model bundle.
            tmp_dir (str): Directory used for downloading and extracting the
                bundle. Also forwarded to the learner build call.

        Returns:
            The result of ``cfg.learner.build(...)``.
        """
        bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(download_if_needed(model_bundle_uri, tmp_dir), bundle_dir)

        raw_config = file_to_json(join(bundle_dir, 'pipeline-config.json'))
        cfg = build_config(upgrade_config(raw_config))
        learner_cfg = cfg.learner

        modules_dir = join(bundle_dir, MODULES_DIRNAME)

        # model definition shipped with the bundle, if any
        model_def_path = None
        model_def_cfg = learner_cfg.model.external_def
        if model_def_cfg is not None:
            model_def_path = get_hubconf_dir_from_cfg(
                model_def_cfg, parent=modules_dir)
            log.info(
                f'Using model definition found in bundle: {model_def_path}')

        # loss function definition shipped with the bundle, if any
        loss_def_path = None
        loss_def_cfg = learner_cfg.solver.external_loss_def
        if loss_def_cfg is not None:
            loss_def_path = get_hubconf_dir_from_cfg(
                loss_def_cfg, parent=modules_dir)
            log.info(f'Using loss definition found in bundle: {loss_def_path}')

        return learner_cfg.build(
            tmp_dir=tmp_dir,
            model_path=join(bundle_dir, 'model.pth'),
            model_def_path=model_def_path,
            loss_def_path=loss_def_path)
示例#2
0
def torch_hub_load_uri(uri: str, hubconf_dir: str, entrypoint: str,
                       tmp_dir: str, *args, **kwargs) -> Any:
    """Load an entrypoint from:
        - a local uri of a zip file, or
        - a local uri of a directory, or
        - a remote uri of zip file.

    The zip file should either have hubconf.py at the top level or contain
    a single sub-directory that contains hubconf.py at its top level. In the
    latter case, the sub-directory will be copied to hubconf_dir.

    Any uri that does not end in ".zip" (case-insensitive) is treated as a
    local directory and is copied to hubconf_dir, unless it already is
    hubconf_dir. In every case where new contents are brought in, whatever
    was previously at hubconf_dir is removed first.

    Args:
        uri (str): A URI.
        hubconf_dir (str): The target directory where the contents from the uri
            will finally be saved to. It does not need to exist beforehand.
        entrypoint (str): Name of a callable present in hubconf.py.
        tmp_dir (str): Directory where the zip file will be downloaded to and
            initially extracted.
        *args: Args to be passed to the entrypoint.
        **kwargs: Keyword args to be passed to the entrypoint.

    Returns:
        Any: The output from calling the entrypoint.
    """

    uri_path = Path(uri)
    is_zip = uri_path.suffix.lower() == '.zip'
    if is_zip:
        # unzip into a scratch dir named after the zip file; clear it first
        # so that leftovers from an earlier run cannot leak in
        zip_path = download_if_needed(uri, tmp_dir)
        unzip_dir = join(tmp_dir, uri_path.stem)
        _remove_dir(unzip_dir)
        unzip(zip_path, target_dir=unzip_dir)
        unzipped_contents = list(glob(f'{unzip_dir}/*', recursive=False))

        _remove_dir(hubconf_dir)

        # if the top level only contains a directory
        if (len(unzipped_contents) == 1) and isdir(unzipped_contents[0]):
            sub_dir = unzipped_contents[0]
            shutil.move(sub_dir, hubconf_dir)
        else:
            shutil.move(unzip_dir, hubconf_dir)

        _remove_dir(unzip_dir)
    # assume uri is local and attempt copying
    else:
        # only copy if needed. The existence check must come first:
        # samefile() is presumably os.path.samefile (TODO confirm against
        # the file's imports), which stats both paths and therefore raises
        # FileNotFoundError when hubconf_dir has not been created yet --
        # which is the normal situation on a first-time copy.
        already_in_place = isdir(hubconf_dir) and samefile(uri, hubconf_dir)
        if not already_in_place:
            _remove_dir(hubconf_dir)
            shutil.copytree(uri, hubconf_dir)

    out = torch_hub_load_local(hubconf_dir, entrypoint, *args, **kwargs)
    return out
示例#3
0
    def from_model_bundle(model_bundle_uri: str, tmp_dir: str):
        """Build a Learner out of the contents of a model bundle.

        The bundle is obtained via ``download_if_needed`` and extracted to
        ``<tmp_dir>/model-bundle``. The ``pipeline-config.json`` inside it is
        upgraded and turned into a config object, whose learner config is
        then used to build the learner from the bundled ``model.pth``.

        Args:
            model_bundle_uri (str): URI of the model bundle.
            tmp_dir (str): Directory used for downloading and extracting the
                bundle. Also forwarded to the learner build call.

        Returns:
            The result of ``cfg.learner.build(...)``.
        """
        bundle_dir = join(tmp_dir, 'model-bundle')
        unzip(download_if_needed(model_bundle_uri, tmp_dir), bundle_dir)

        raw_config = file_to_json(join(bundle_dir, 'pipeline-config.json'))
        cfg = build_config(upgrade_config(raw_config))

        weights_path = join(bundle_dir, 'model.pth')
        return cfg.learner.build(tmp_dir, model_path=weights_path)
示例#4
0
def read_stac(uri: str, unzip_dir: Optional[str] = None) -> List[dict]:
    """Parse the contents of a STAC catalog (downloading it first, if
    remote). If the uri is a zip file, unzip it, find catalog.json inside it
    and parse that.

    A uri is considered a zip file if it ends in ".zip" (case-insensitive).
    The search for catalog.json is recursive and covers everything under
    unzip_dir, not just the files extracted by this call.

    Args:
        uri (str): Either a URI to a STAC catalog JSON file or a URI to a zip
            file containing a STAC catalog JSON file.
        unzip_dir (Optional[str]): Directory to extract the zip file to.
            Required if uri is a zip file; unused otherwise.
            Defaults to None.

    Raises:
        ValueError: If uri is a zip file but unzip_dir is None. This is
            checked before anything is downloaded.
        FileNotFoundError: If catalog.json is not found inside the zip file.
        Exception: If multiple catalog.json's are found inside the zip file.

    Returns:
        List[dict]: A list of dicts with keys: "label_uri", "image_uris",
            "label_bbox", "image_bbox", "bboxes_intersect", and "aoi_geometry".
            Each dict corresponds to one label item and its associated image
            assets in the STAC catalog.
    """
    is_zip = Path(uri).suffix.lower() == '.zip'

    # Fail fast: without an extraction target the call cannot succeed, so
    # there is no point in downloading the archive first.
    if is_zip and unzip_dir is None:
        raise ValueError(
            f'uri ("{uri}") is a zip file, but no unzip_dir provided.')

    # The download lives in a throwaway dir; only the extracted files in
    # unzip_dir outlive this call.
    with TemporaryDirectory() as tmp_dir:
        downloaded_path = download_if_needed(uri, tmp_dir)
        if not is_zip:
            return parse_stac(downloaded_path)

        unzip(downloaded_path, target_dir=unzip_dir)
        catalog_paths = list(Path(unzip_dir).glob('**/catalog.json'))
        if not catalog_paths:
            raise FileNotFoundError(f'Unable to find "catalog.json" in {uri}.')
        if len(catalog_paths) > 1:
            raise Exception(f'More than one "catalog.json" found in '
                            f'{uri}.')
        return parse_stac(str(catalog_paths[0]))