Ejemplo n.º 1
0
    def test_resolve(self, configs, expected_id, output_type):
        """Resolve ``expected_id`` from ``configs`` and check the result type, eagerly and lazily.

        Each entry of ``configs`` is wrapped in the config-item class matching its content
        (instantiable component, expression, or plain item) and registered in a resolver.
        """
        locator = ComponentLocator()
        resolver = ReferenceResolver()
        expr_globals = {"monai": monai, "torch": torch}
        # add items to resolver, choosing the item class from the content of each value
        for k, v in configs.items():
            if ConfigComponent.is_instantiable(v):
                item = ConfigComponent(config=v, id=k, locator=locator)
            elif ConfigExpression.is_expression(v):
                item = ConfigExpression(config=v, id=k, globals=expr_globals)
            else:
                item = ConfigItem(config=v, id=k)
            resolver.add_item(item)

        # the root id is `expected_id` here
        result = resolver.get_resolved_content(expected_id)
        # assertIsInstance reports the actual type on failure, unlike assertTrue(isinstance(...))
        self.assertIsInstance(result, output_type)

        # test lazy instantiation: clear the `_disabled_` flag, then build the object on demand
        item = resolver.get_item(expected_id, resolve=True)
        config = item.get_config()
        config["_disabled_"] = False
        item.update_config(config=config)
        if isinstance(item, ConfigComponent):
            result = item.instantiate()
        else:
            result = item.get_config()
        self.assertIsInstance(result, output_type)
Ejemplo n.º 2
0
    def _do_parse(self, config, id: str = ""):
        """
        Recursively parse the nested data in config source and add every item to the resolver.

        Each (sub-)item is registered as a ``ConfigComponent`` if it is instantiable, as a
        ``ConfigExpression`` if it is an expression, and as a plain ``ConfigItem`` otherwise.

        Args:
            config: config source to parse.
            id: id of the ``ConfigItem``, ``"#"`` in id are interpreted as special characters to
                go one level further into the nested structures.
                Use digits indexing from "0" for list or other strings for dict.
                For example: ``"xform#5"``, ``"net#channels"``. ``""`` indicates the entire ``self.config``.

        """
        if isinstance(config, (dict, list)):
            for k, v in enumerate(config) if isinstance(config, list) else config.items():
                # always produce a string id: at the root of a list config ``k`` is an ``int``,
                # which would otherwise be registered as a non-string id (the reference
                # resolver builds its ids with ``f"{k}"`` as well)
                sub_id = f"{id}{ID_SEP_KEY}{k}" if id != "" else f"{k}"
                self._do_parse(config=v, id=sub_id)

        # copy every config item to make them independent and add them to the resolver
        item_conf = deepcopy(config)
        if ConfigComponent.is_instantiable(item_conf):
            self.ref_resolver.add_item(ConfigComponent(config=item_conf, id=id, locator=self.locator))
        elif ConfigExpression.is_expression(item_conf):
            self.ref_resolver.add_item(ConfigExpression(config=item_conf, id=id, globals=self.globals))
        else:
            self.ref_resolver.add_item(ConfigItem(config=item_conf, id=id))
Ejemplo n.º 3
0
    def update_config_with_refs(cls, config, id: str, refs: Optional[Dict] = None):
        """
        Build a new config from ``config`` with every reference replaced by its resolved content.

        Strings are rewritten through ``cls.update_refs_pattern``. Lists and dicts are rebuilt
        item by item. Instantiable or expression sub-items are taken directly from ``refs``
        (a missing id raises ``KeyError``). An instantiable sub-item whose resolved value is
        ``None`` (a disabled component) is omitted from the result. Any other value is
        returned unchanged.

        Args:
            config: input config content to update.
            id: ID name for the input config.
            refs: all the referring content with ids, default to `None`.

        """
        lookup: Dict = refs or {}
        if isinstance(config, str):
            return cls.update_refs_pattern(config, lookup)
        if not isinstance(config, (list, dict)):
            return config

        is_mapping = isinstance(config, dict)
        out = type(config)()
        pairs = config.items() if is_mapping else enumerate(config)
        for key, sub in pairs:
            child_id = f"{id}{cls.sep}{key}" if id != "" else f"{key}"
            instantiable = ConfigComponent.is_instantiable(sub)
            if instantiable or ConfigExpression.is_expression(sub):
                resolved = lookup[child_id]
                if instantiable and resolved is None:
                    # the component is disabled: leave it out of the rebuilt container
                    continue
            else:
                resolved = cls.update_config_with_refs(sub, child_id, lookup)

            if is_mapping:
                out[key] = resolved
            else:
                out.append(resolved)
        return out
Ejemplo n.º 4
0
    def find_refs_in_config(
            cls,
            config,
            id: str,
            refs: Optional[Dict[str, int]] = None) -> Dict[str, int]:
        """
        Recursively search all the content of input config item to get the ids of references.
        References mean: the IDs of other config items (``"@XXX"`` in this config item), or the
        sub-item in the config is `instantiable`, or the sub-item in the config is `expression`.
        For `dict` and `list`, recursively check the sub-items.

        Args:
            config: input config content to search.
            id: ID name for the input config item.
            refs: dict of the ID name and count of found references, default to `None`.

        Returns:
            The updated dict mapping each referenced ID to its reference count.

        """
        refs_: Dict[str, int] = refs or {}
        if isinstance(config, str):
            # distinct loop name so the ``id`` parameter is not shadowed
            for ref_id, count in cls.match_refs_pattern(value=config).items():
                refs_[ref_id] = refs_.get(ref_id, 0) + count
        if not isinstance(config, (list, dict)):
            return refs_
        for k, v in config.items() if isinstance(config,
                                                 dict) else enumerate(config):
            sub_id = f"{id}{cls.sep}{k}" if id != "" else f"{k}"
            # parentheses are required: ``and`` binds tighter than ``or``, so without them an
            # instantiable sub-item already counted in ``refs_`` would have its count reset to 1
            is_ref_item = (ConfigComponent.is_instantiable(v)
                           or ConfigExpression.is_expression(v))
            if is_ref_item and sub_id not in refs_:
                refs_[sub_id] = 1
            refs_ = cls.find_refs_in_config(v, sub_id, refs_)
        return refs_
Ejemplo n.º 5
0
 def test_circular_references(self):
     """Each id in the cycle A -> B -> C -> A must raise ``ValueError`` when resolved."""
     locator = ComponentLocator()
     resolver = ReferenceResolver()
     cyclic = {"A": "@B", "B": "@C", "C": "@A"}
     for item_id, content in cyclic.items():
         component = ConfigComponent(config=content, id=item_id, locator=locator)
         resolver.add_item(component)
     # dict iteration yields "A", "B", "C" in insertion order
     for item_id in cyclic:
         with self.assertRaises(ValueError):
             resolver.get_resolved_content(item_id)
Ejemplo n.º 6
0
def load(
    name: str,
    model_file: Optional[str] = None,
    load_ts_module: bool = False,
    bundle_dir: Optional[PathLike] = None,
    source: str = "github",
    repo: Optional[str] = None,
    progress: bool = True,
    device: Optional[str] = None,
    config_files: Sequence[str] = (),
    net_name: Optional[str] = None,
    **net_kwargs,
):
    """
    Load model weights or TorchScript module of a bundle.

    Args:
        name: bundle name.
        model_file: the relative path of the model weights or TorchScript module within bundle.
            If `None`, "models/model.pt" or "models/model.ts" will be used.
        load_ts_module: a flag to specify if loading the TorchScript module.
        bundle_dir: the directory the weights/TorchScript module will be loaded from.
            Default is `bundle` subfolder under `torch.hub.get_dir()`.
        source: the place that saved the bundle. Only used when the file does not exist locally
            and the bundle has to be downloaded first.
            If `source` is `github`, the bundle should be within the releases.
        repo: the repo name. If the weights file does not exist locally, it must be provided so
            the bundle can be downloaded.
            If `source` is `github`, it should be in the form of `repo_owner/repo_name/release_tag`.
            For example: `Project-MONAI/MONAI-extra-test-data/0.8.1`.
        progress: whether to display a progress bar when downloading.
        device: target device of returned weights or module. If `None`, "cuda:0" is used when
            `is_available()` is true, otherwise "cpu".
        config_files: extra filenames would be loaded. The argument only works when loading a TorchScript module,
            see `_extra_files` in `torch.jit.load` for more details.
        net_name: if not `None`, a corresponding network will be instantiated and load the achieved weights.
            This argument only works when loading weights.
        net_kwargs: other arguments that are used to instantiate the network class defined by `net_name`.

    Returns:
        1. If `load_ts_module` is `False` and `net_name` is `None`, return model weights.
        2. If `load_ts_module` is `False` and `net_name` is not `None`,
            return an instantiated network that loaded the weights.
        3. If `load_ts_module` is `True`, return a triple that include a TorchScript module,
            the corresponding metadata dict, and extra files dict.
            please check `monai.data.load_net_with_metadata` for more details.

    """
    bundle_dir_ = _process_bundle_dir(bundle_dir)

    if model_file is None:
        model_file = os.path.join(
            "models", "model.ts" if load_ts_module is True else "model.pt")
    full_path = os.path.join(bundle_dir_, name, model_file)
    if not os.path.exists(full_path):
        # the requested file is not on disk yet: fetch the bundle first
        download(name=name,
                 bundle_dir=bundle_dir_,
                 source=source,
                 repo=repo,
                 progress=progress)

    if device is None:
        device = "cuda:0" if is_available() else "cpu"
    map_location = torch.device(device)
    # loading with `torch.jit.load`
    if load_ts_module is True:
        return load_net_with_metadata(full_path,
                                      map_location=map_location,
                                      more_extra_files=config_files)
    # loading with `torch.load`
    model_dict = torch.load(full_path, map_location=map_location)

    if net_name is None:
        return model_dict
    # `net_kwargs` is a fresh dict built from **kwargs, so adding `_target_` never touches caller data
    net_kwargs["_target_"] = net_name
    model = ConfigComponent(config=net_kwargs).instantiate()
    model.to(device)  # type: ignore
    model.load_state_dict(model_dict)  # type: ignore
    return model
Ejemplo n.º 7
0
def load(
    name: str,
    model_file: Optional[str] = None,
    load_ts_module: bool = False,
    bundle_dir: Optional[PathLike] = None,
    source: str = "github",
    repo: str = "Project-MONAI/model-zoo/hosting_storage_v1",
    progress: bool = True,
    device: Optional[str] = None,
    key_in_ckpt: Optional[str] = None,
    config_files: Sequence[str] = (),
    net_name: Optional[str] = None,
    **net_kwargs,
):
    """
    Load the model weights or the TorchScript module stored in a bundle.

    The file is downloaded first when it is not already present under the bundle directory.

    Args:
        name: bundle name.
        model_file: path of the weights/TorchScript file relative to the bundle folder.
            Defaults to "models/model.ts" when loading TorchScript, else "models/model.pt".
        load_ts_module: load a TorchScript module (via `load_net_with_metadata`) instead of weights.
        bundle_dir: directory the weights/TorchScript module will be loaded from.
            Default is `bundle` subfolder under `torch.hub.get_dir()`.
        source: storage location name, only consulted when the file must be downloaded.
            "github" is currently the only supported value.
        repo: repo name, only consulted when the file must be downloaded. If `source` is "github",
            it should be in the form of "repo_owner/repo_name/release_tag".
        progress: whether to display a progress bar when downloading.
        device: target device of returned weights or module. If `None`, "cuda:0" is used when
            `is_available()` is true, otherwise "cpu".
        key_in_ckpt: for a nested checkpoint like `{"model": XXX, "optimizer": XXX, ...}`, the key
            holding the model weights. Only applied when `net_name` is given; otherwise the whole
            loaded dictionary is returned.
        config_files: extra filenames to load; only used for TorchScript modules,
            see `_extra_files` in `torch.jit.load`.
        net_name: if not `None`, instantiate this network and copy the loaded weights into it.
            Ignored when loading a TorchScript module.
        net_kwargs: extra arguments used to instantiate the network named by `net_name`.

    Returns:
        1. `load_ts_module` false and `net_name` `None`: the loaded weights.
        2. `load_ts_module` false and `net_name` set: the instantiated network holding the weights.
        3. `load_ts_module` true: the triple returned by `monai.data.load_net_with_metadata`
            (TorchScript module, metadata dict, extra files dict).

    """
    root_dir = _process_bundle_dir(bundle_dir)

    if model_file is None:
        default_name = "model.ts" if load_ts_module is True else "model.pt"
        model_file = os.path.join("models", default_name)
    full_path = os.path.join(root_dir, name, model_file)

    # fetch the bundle only when the requested file is missing on disk
    if not os.path.exists(full_path):
        download(name=name, bundle_dir=root_dir, source=source, repo=repo, progress=progress)

    if device is None:
        device = "cuda:0" if is_available() else "cpu"
    map_location = torch.device(device)

    if load_ts_module is True:
        # TorchScript path: delegate to `torch.jit.load`-based helper
        return load_net_with_metadata(full_path, map_location=map_location, more_extra_files=config_files)

    # plain weights path: `torch.load`
    model_dict = torch.load(full_path, map_location=map_location)
    if not isinstance(model_dict, Mapping):
        warnings.warn(f"the state dictionary from {full_path} should be a dictionary but got {type(model_dict)}.")
        model_dict = get_state_dict(model_dict)

    if net_name is None:
        return model_dict

    net_kwargs["_target_"] = net_name
    network = ConfigComponent(config=net_kwargs).instantiate()
    network.to(device)  # type: ignore
    weights = model_dict if key_in_ckpt is None else model_dict[key_in_ckpt]
    copy_model_state(dst=network, src=weights)  # type: ignore
    return network