Example #1
 def test_set_prefix(self, device_0, device_1):
     set_determinism(0)
     model_one = torch.nn.Sequential(_TestModelOne(10, 20, 3))
     model_two = _TestModelTwo(10, 20, 10, 4)
     model_one.to(device_0)
     model_two.to(device_1)
     # test skip layer.bias
     model_dict, ch, unch = copy_model_state(model_one,
                                             model_two,
                                             dst_prefix="0.",
                                             exclude_vars="layer.bias",
                                             inplace=False)
     model_one.load_state_dict(model_dict)
     x = np.random.randn(4, 10)
     x = torch.tensor(x, device=device_0, dtype=torch.float32)
     output = model_one(x).detach().cpu().numpy()
     expected = np.array([
         [-0.360766, -0.031778, -0.770227],
         [-0.052683, -0.158559, -0.011493],
         [-0.376051, -0.224852, -0.063404],
         [0.597767, -0.679911, 0.19195],
     ])
     np.testing.assert_allclose(output, expected, atol=1e-3)
     self.assertEqual(len(ch), 2)
     self.assertEqual(len(unch), 2)
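
Across these tests, `copy_model_state(dst, src, ...)` returns a three-tuple: the merged state dict, the list of destination keys that were overwritten, and the list left untouched. A minimal standalone sketch of that calling pattern, assuming the `monai.networks.utils` import path:

import torch
from monai.networks.utils import copy_model_state  # import path assumed

dst = torch.nn.Linear(10, 20)
src = torch.nn.Linear(10, 20)
# merge src's parameters into a copy of dst's state dict
state, changed, unchanged = copy_model_state(dst, src, inplace=False)
dst.load_state_dict(state)  # with inplace=False, apply the result explicitly
print(f"copied {len(changed)} entries, skipped {len(unchanged)}")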
Example #2
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        checkpoint = torch.load(self.load_path, map_location=self.map_location)

        if not self.strict_shape:
            k, _ = list(self.load_dict.items())[0]
            # single object and checkpoint is directly a state_dict
            if len(self.load_dict) == 1 and k not in checkpoint:
                checkpoint = {k: checkpoint}

            # skip items that don't match data shape
            for k, obj in self.load_dict.items():
                checkpoint[k] = copy_model_state(obj, checkpoint, inplace=False)[0]

        # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
        prior_max_epochs = engine.state.max_epochs
        Checkpoint.load_objects(to_load=self.load_dict, checkpoint=checkpoint, strict=self.strict)
        if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs:
            raise ValueError(
                f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "
                f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, "
                "construct trainer with `max_epochs` larger than checkpoint's epoch count. "
                "To use checkpoint for inference, no need to load state_dict for the engine."
            )
        engine.state.max_epochs = prior_max_epochs

        self.logger.info(f"Restored all variables from {self.load_path}")
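
A hypothetical wiring sketch for this handler; the constructor arguments are inferred from the attributes the method reads (`load_path`, `load_dict`, `map_location`, `strict_shape`, `strict`), not from a verified signature:

import torch
from ignite.engine import Engine
from monai.handlers import CheckpointLoader  # import path assumed

net = torch.nn.Linear(10, 10)
trainer = Engine(lambda engine, batch: None)  # stand-in process function
loader = CheckpointLoader(  # argument names assumed from the attributes above
    load_path="./checkpoint.pt",
    load_dict={"net": net},
    map_location="cpu",
    strict_shape=False,
)
loader(trainer)  # invokes __call__, restoring `net` from the checkpoint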
Example #3
 def test_set_map_across(self, device_0, device_1):
     set_determinism(0)
     model_one = _TestModelOne(10, 10, 3)
     model_two = _TestModelTwo(10, 10, 10, 4)
     model_one.to(device_0)
     model_two.to(device_1)
     # test weight map
     model_dict, ch, unch = copy_model_state(
         model_one,
         model_two,
         mapping={"layer_1.weight": "layer.weight", "layer_1.bias": "layer_1.weight"},
     )
     model_one.load_state_dict(model_dict)
     x = np.random.randn(4, 10)
     x = torch.tensor(x, device=device_0, dtype=torch.float32)
     output = model_one(x).detach().cpu().numpy()
     expected = np.array([
         [0.8244487, -0.19650555, 0.65723234],
         [0.71239626, 0.25617486, 0.5247122],
         [0.24168758, 1.0301148, 0.39089814],
         [0.25791705, 0.8653245, 0.14833644],
     ])
     np.testing.assert_allclose(output, expected, atol=1e-3)
     self.assertEqual(len(ch), 2)
     self.assertEqual(len(unch), 2)
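
The mapping direction here appears to be `{source_key: destination_key}` (so `model_two`'s `layer_1.weight` overwrites `model_one`'s `layer.weight`); that reading is inferred from the assertion counts above, not from documentation. A standalone sketch of renaming keys between mismatched architectures:

import torch
from monai.networks.utils import copy_model_state  # import path assumed

src = torch.nn.Sequential(torch.nn.Linear(10, 10))  # keys: "0.weight", "0.bias"
dst = torch.nn.Linear(10, 10)                       # keys: "weight", "bias"
# no names match by default, so only the mapped pairs are copied
state, ch, unch = copy_model_state(
    dst, src, mapping={"0.weight": "weight", "0.bias": "bias"}, inplace=False
)
dst.load_state_dict(state)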
Example #4
    def test_loading_mmar(self, item):
        if item["name"] == "clara_pt_self_supervised_learning_segmentation":  # test the byow model
            default_model_file = os.path.join("ssl_models_2gpu", "best_metric_model.pt")
            pretrained_weights = load_from_mmar(
                item=item["name"],
                mmar_dir="./",
                map_location="cpu",
                api=True,
                model_file=default_model_file,
                weights_only=True,
            )
            pretrained_weights = {k.split(".", 1)[1]: v for k, v in pretrained_weights["state_dict"].items()}
            sys.path.append(os.path.join(f"{item['name']}", "custom"))  # custom model folder
            from vit_network import ViTAutoEnc  # pylint: disable=E0401

            model = ViTAutoEnc(
                in_channels=1,
                img_size=(96, 96, 96),
                patch_size=(16, 16, 16),
                pos_embed="conv",
                hidden_size=768,
                mlp_dim=3072,
            )
            _, loaded, not_loaded = copy_model_state(model, pretrained_weights)
            self.assertTrue(len(loaded) > 0 and len(not_loaded) == 0)
            return
        if item["name"] == "clara_pt_fed_learning_brain_tumor_mri_segmentation":
            default_model_file = os.path.join("models", "server", "best_FL_global_model.pt")
        else:
            default_model_file = None
        pretrained_model = load_from_mmar(
            item=item["name"], mmar_dir="./", map_location="cpu", api=True, model_file=default_model_file
        )
        self.assertTrue(isinstance(pretrained_model, torch.nn.Module))
Example #5
 def test_set_full_state(self, device_0, device_1):
     set_determinism(0)
     model_one = _TestModelOne(10, 20, 3)
     model_two = _TestModelOne(10, 20, 3)
     model_one.to(device_0)
     model_two.to(device_1)
     # test module input
     model_dict, ch, unch = copy_model_state(model_one, model_two)
     # test dict input
     model_dict, ch, unch = copy_model_state(model_dict, model_two)
     x = np.random.randn(4, 10)
     x = torch.tensor(x, device=device_0, dtype=torch.float32)
     output = model_one(x).detach().cpu().numpy()
     model_two.to(device_0)
     output_1 = model_two(x).detach().cpu().numpy()
     np.testing.assert_allclose(output, output_1, atol=1e-3)
     self.assertEqual(len(ch), 4)
     self.assertEqual(len(unch), 0)
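
Because the first argument can be a plain state dict as well as a module (the second call above), the merged result can also be persisted directly; a small sketch using standard torch serialization (the file name is illustrative):

merged, _, _ = copy_model_state(model_one, model_two, inplace=False)
torch.save(merged, "merged_state.pt")  # hypothetical output path
model_one.load_state_dict(torch.load("merged_state.pt"))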
Example #6
 def test_set_state(self, device_0, device_1):
     set_determinism(0)
     model_one = _TestModelOne(10, 20, 3)
     model_two = _TestModelTwo(10, 20, 10, 4)
     model_one.to(device_0)
     model_two.to(device_1)
     model_dict, ch, unch = copy_model_state(model_one, model_two)
     x = np.random.randn(4, 10)
     x = torch.tensor(x, device=device_0, dtype=torch.float32)
     output = model_one(x).detach().cpu().numpy()
     expected = np.array([
         [-0.36076584, -0.03177825, -0.7702266],
         [-0.0526831, -0.15855855, -0.01149344],
         [-0.3760508, -0.22485238, -0.0634037],
         [0.5977675, -0.67991066, 0.1919502],
     ])
     np.testing.assert_allclose(output, expected, atol=1e-3)
     self.assertEqual(len(ch), 2)
     self.assertEqual(len(unch), 2)
Example #7
    def __call__(self, engine: Engine) -> None:
        """
        Args:
            engine: Ignite Engine, it can be a trainer, validator or evaluator.
        """
        checkpoint = torch.load(self.load_path, map_location=self.map_location)

        k, _ = list(self.load_dict.items())[0]
        # single object and checkpoint is directly a state_dict
        if len(self.load_dict) == 1 and k not in checkpoint:
            checkpoint = {k: checkpoint}

        if not self.strict_shape:
            pop_items: List[str] = []
            for k, obj in self.load_dict.items():
                if isinstance(obj, torch.nn.Module):
                    # skip items that don't match key name or data shape
                    checkpoint[k] = copy_model_state(obj,
                                                     checkpoint,
                                                     inplace=False)[0]
                else:
                    warnings.warn(
                        "`strict_shape` is False, load checkpoint for model, skip others in `load_dict`."
                    )
                    pop_items.append(k)
            for i in pop_items:
                self.load_dict.pop(i)

        # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
        prior_max_epochs = engine.state.max_epochs
        Checkpoint.load_objects(to_load=self.load_dict,
                                checkpoint=checkpoint,
                                strict=self.strict)
        if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs:
            raise ValueError(
                f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "
                f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, "
                "construct trainer with `max_epochs` larger than checkpoint's epoch count. "
                "To use checkpoint for inference, no need to load state_dict for the engine."
            )
        engine.state.max_epochs = prior_max_epochs

        self.logger.info(f"Restored all variables from {self.load_path}")
Example #8
 def test_set_exclude_vars(self, device_0, device_1):
     set_determinism(0)
     model_one = _TestModelOne(10, 20, 3)
     model_two = _TestModelTwo(10, 20, 10, 4)
     model_one.to(device_0)
     model_two.to(device_1)
     # test skip layer.bias
     model_dict, ch, unch = copy_model_state(model_one,
                                             model_two,
                                             exclude_vars="layer.bias")
     x = np.random.randn(4, 10)
     x = torch.tensor(x, device=device_0, dtype=torch.float32)
     output = model_one(x).detach().cpu().numpy()
     expected = np.array([
         [-0.34172416, 0.0375042, -0.98340976],
         [-0.03364138, -0.08927619, -0.2246768],
         [-0.35700908, -0.15556987, -0.27658707],
         [0.61680925, -0.6106281, -0.02123314],
     ])
     np.testing.assert_allclose(output, expected, atol=1e-3)
     self.assertEqual(len(ch), 1)
     self.assertEqual(len(unch), 3)
Example #9
File: mmars.py  Project: Nic-Ma/MONAI
def load_from_mmar(
    item,
    mmar_dir: Optional[PathLike] = None,
    progress: bool = True,
    version: int = -1,
    map_location=None,
    pretrained=True,
    weights_only=False,
    model_key: str = "model",
    api: bool = True,
    model_file=None,
):
    """
    Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train.

    Args:
        item: the corresponding model item from `MODEL_DESC`.
        mmar_dir: target directory to store the MMAR, default is the "mmars" subfolder under `torch.hub.get_dir()`.
        progress: whether to display a progress bar when downloading the content.
        version: version number of the MMAR. Set it to `-1` to use `item[Keys.VERSION]`.
        map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.
        pretrained: whether to load the pretrained weights after initializing a network module.
        weights_only: whether to load only the weights instead of initializing the network module and assigning the weights.
        model_key: a key to search in the model file or config file for the model dictionary.
            Currently this function assumes that the model dictionary has
            `{"[name|path]": "test.module", "args": {'kw': 'test'}}`.
        api: whether to query the NGC API to get model information.
        model_file: the relative path to the model file within an MMAR.

    Examples::
        >>> from monai.apps import load_from_mmar
        >>> unet_model = load_from_mmar("clara_pt_prostate_mri_segmentation_1", mmar_dir=".", map_location="cpu")
        >>> print(unet_model)

    See Also:
        https://docs.nvidia.com/clara/
    """
    if api:
        item = {Keys.NAME: get_model_spec(item)[Keys.NAME] if isinstance(item, int) else f"{item}"}
    if not isinstance(item, Mapping):
        item = get_model_spec(item)
    model_dir = download_mmar(item=item,
                              mmar_dir=mmar_dir,
                              progress=progress,
                              version=version,
                              api=api)
    if model_file is None:
        model_file = os.path.join("models", "model.pt")
    model_file = model_dir / item.get(Keys.MODEL_FILE, model_file)
    logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.')

    # loading with `torch.jit.load`
    if model_file.name.endswith(".ts"):
        if not pretrained:
            warnings.warn(
                "Loading a ScriptModule, 'pretrained' option ignored.")
        if weights_only:
            warnings.warn(
                "Loading a ScriptModule, 'weights_only' option ignored.")
        return torch.jit.load(model_file, map_location=map_location)

    # loading with `torch.load`
    model_dict = torch.load(model_file, map_location=map_location)
    if weights_only:
        return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly

    # 1. search `model_dict['train_conf']` for model config spec.
    model_config = _get_val(dict(model_dict).get("train_conf", {}),
                            key=model_key,
                            default={})
    if not model_config or not isinstance(model_config, Mapping):
        # 2. search json CONFIG_FILE for model config spec.
        json_path = model_dir / item.get(
            Keys.CONFIG_FILE, os.path.join("config", "config_train.json"))
        with open(json_path) as f:
            conf_dict = json.load(f)
        conf_dict = dict(conf_dict)
        model_config = _get_val(conf_dict, key=model_key, default={})
    if not model_config:
        # 3. search `model_dict` for model config spec.
        model_config = _get_val(dict(model_dict), key=model_key, default={})

    if not (model_config and isinstance(model_config, Mapping)):
        raise ValueError(
            f"Could not load model config dictionary from config: {item.get(Keys.CONFIG_FILE)}, "
            f"or from model file: {item.get(Keys.MODEL_FILE)}.")

    # parse `model_config` for model class and model parameters
    if model_config.get("name"):  # model config section is a "name"
        model_name = model_config["name"]
        model_cls = monai_nets.__dict__[model_name]
    elif model_config.get("path"):  # model config section is a "path"
        # https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html
        model_module, model_name = model_config.get("path", ".").rsplit(".", 1)
        model_cls, has_cls = optional_import(module=model_module,
                                             name=model_name)
        if not has_cls:
            raise ValueError(
                f"Could not load MMAR model config {model_config.get('path', '')}. "
                f"Please make sure the MMAR's sub-folders in '{model_dir}' are on the PYTHONPATH. "
                "See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
            )
    else:
        raise ValueError(f"Could not load model config {model_config}.")

    logger.info(f"*** Model: {model_cls}")
    model_kwargs = model_config.get("args", None)
    if model_kwargs:
        model_inst = model_cls(**model_kwargs)
        logger.info(f"*** Model params: {model_kwargs}")
    else:
        model_inst = model_cls()
    if pretrained:
        _, changed, unchanged = copy_model_state(
            model_inst, model_dict.get(model_key, model_dict), inplace=True
        )
        if not (changed and not unchanged):  # not all model_inst variables are changed
            logger.warning(f"*** Loading model state -- unchanged: {len(unchanged)}, changed: {len(changed)}.")
    logger.info("\n---")
    doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(
        item[Keys.NAME], model_prefix="nvidia:med:")
    logger.info(f"For more information, please visit {doc_url}\n")
    return model_inst
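
Putting the two entry points together: a hedged end-to-end sketch that fetches raw weights with `weights_only=True` and merges them into a locally built network, mirroring Example #4 (the item name comes from the docstring example; the stand-in network is illustrative and would normally match the MMAR's architecture):

import torch
from monai.apps import load_from_mmar
from monai.networks.utils import copy_model_state  # import path assumed

weights = load_from_mmar(
    item="clara_pt_prostate_mri_segmentation_1",  # name from the docstring example
    mmar_dir=".",
    map_location="cpu",
    weights_only=True,
)
net = torch.nn.Linear(8, 8)  # stand-in; mismatched keys are simply skipped
_, loaded, not_loaded = copy_model_state(net, weights, inplace=True)
print(f"loaded {len(loaded)} tensors, skipped {len(not_loaded)}")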