def test_set_prefix(self, device_0, device_1):
    """Copy into a ``Sequential``-wrapped model through ``dst_prefix`` while skipping ``layer.bias``."""
    set_determinism(0)
    model_one = torch.nn.Sequential(_TestModelOne(10, 20, 3))
    model_two = _TestModelTwo(10, 20, 10, 4)
    model_one.to(device_0)
    model_two.to(device_1)

    # the wrapped destination names its parameters "0.<name>", hence dst_prefix="0.";
    # with inplace=False the merged state has to be loaded explicitly afterwards
    merged_state, changed, unchanged = copy_model_state(
        model_one,
        model_two,
        dst_prefix="0.",
        exclude_vars="layer.bias",
        inplace=False,
    )
    model_one.load_state_dict(merged_state)

    inputs = torch.tensor(np.random.randn(4, 10), device=device_0, dtype=torch.float32)
    prediction = model_one(inputs)
    result = prediction.detach().cpu().numpy()
    reference = np.array(
        [
            [-0.360766, -0.031778, -0.770227],
            [-0.052683, -0.158559, -0.011493],
            [-0.376051, -0.224852, -0.063404],
            [0.597767, -0.679911, 0.19195],
        ]
    )
    np.testing.assert_allclose(result, reference, atol=1e-3)
    self.assertEqual(len(changed), 2)
    self.assertEqual(len(unchanged), 2)
def __call__(self, engine: Engine) -> None:
    """
    Load the checkpoint stored at ``self.load_path`` into the objects of ``self.load_dict``.

    When ``self.strict_shape`` is False, only ``torch.nn.Module`` items are loaded. For each of them,
    only the checkpoint entries whose key name and data shape match the module are copied. Other
    items are skipped with a warning.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.

    Raises:
        ValueError: when the epoch count restored into ``engine.state.epoch`` is larger than the
            ``max_epochs`` the engine was configured with.
    """
    import warnings

    checkpoint = torch.load(self.load_path, map_location=self.map_location)

    # shallow copy: skipping items below must not permanently remove them from the handler's config
    load_dict = dict(self.load_dict)

    if len(load_dict) == 1:
        (name,) = load_dict
        # single object and checkpoint is directly a state_dict
        if name not in checkpoint:
            checkpoint = {name: checkpoint}

    if not self.strict_shape:
        for name, obj in list(load_dict.items()):
            if not isinstance(obj, torch.nn.Module):
                warnings.warn("`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.")
                load_dict.pop(name)
            elif name in checkpoint:
                # skip entries that don't match key name or data shape; the source must be this
                # object's own entry, the top-level checkpoint is keyed by object names
                checkpoint[name] = copy_model_state(obj, checkpoint[name], inplace=False)[0]
            # a missing `name` is left for `Checkpoint.load_objects` to report

    # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
    prior_max_epochs = engine.state.max_epochs
    Checkpoint.load_objects(to_load=load_dict, checkpoint=checkpoint, strict=self.strict)
    # `max_epochs` may legitimately be None, in which case there is nothing to compare against
    if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs:
        raise ValueError(
            f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "
            f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, "
            "construct trainer with `max_epochs` larger than checkpoint's epoch count. "
            "To use checkpoint for inference, no need to load state_dict for the engine."
        )
    engine.state.max_epochs = prior_max_epochs

    self.logger.info(f"Restored all variables from {self.load_path}")
def test_set_map_across(self, device_0, device_1):
    """Copy weights between models whose parameter names differ, using an explicit ``mapping``."""
    set_determinism(0)
    model_one = _TestModelOne(10, 10, 3)
    model_two = _TestModelTwo(10, 10, 10, 4)
    model_one.to(device_0)
    model_two.to(device_1)

    # test weight map; presumably keys name model_two (source) entries and values name
    # model_one (destination) entries -- confirm against copy_model_state
    key_map = {
        "layer_1.weight": "layer.weight",
        "layer_1.bias": "layer_1.weight",
    }
    new_state, changed, unchanged = copy_model_state(model_one, model_two, mapping=key_map)
    model_one.load_state_dict(new_state)

    inputs = torch.tensor(np.random.randn(4, 10), device=device_0, dtype=torch.float32)
    prediction = model_one(inputs)
    result = prediction.detach().cpu().numpy()
    reference = np.array(
        [
            [0.8244487, -0.19650555, 0.65723234],
            [0.71239626, 0.25617486, 0.5247122],
            [0.24168758, 1.0301148, 0.39089814],
            [0.25791705, 0.8653245, 0.14833644],
        ]
    )
    np.testing.assert_allclose(result, reference, atol=1e-3)
    self.assertEqual(len(changed), 2)
    self.assertEqual(len(unchanged), 2)
def test_loading_mmar(self, item):
    """Download an MMAR with ``load_from_mmar`` and check that its model (or its weights) can be loaded."""
    if item["name"] == "clara_pt_self_supervised_learning_segmentation":
        # test the byow model: only the weights come from the MMAR, the network is built from its custom code
        default_model_file = os.path.join("ssl_models_2gpu", "best_metric_model.pt")
        pretrained_weights = load_from_mmar(
            item=item["name"],
            mmar_dir="./",
            map_location="cpu",
            api=True,
            model_file=default_model_file,
            weights_only=True,
        )
        # drop the first dotted component of every key (assumes every key carries such a prefix)
        pretrained_weights = {k.split(".", 1)[1]: v for k, v in pretrained_weights["state_dict"].items()}

        custom_dir = os.path.join(f"{item['name']}", "custom")  # custom model folder
        sys.path.append(custom_dir)
        try:
            from vit_network import ViTAutoEnc  # pylint: disable=E0401

            model = ViTAutoEnc(
                in_channels=1,
                img_size=(96, 96, 96),
                patch_size=(16, 16, 16),
                pos_embed="conv",
                hidden_size=768,
                mlp_dim=3072,
            )
            _, loaded, not_loaded = copy_model_state(model, pretrained_weights)
        finally:
            # undo the sys.path change so it does not leak into other tests
            sys.path.remove(custom_dir)
        self.assertGreater(len(loaded), 0)
        self.assertEqual(len(not_loaded), 0)
        return

    if item["name"] == "clara_pt_fed_learning_brain_tumor_mri_segmentation":
        default_model_file = os.path.join("models", "server", "best_FL_global_model.pt")
    else:
        default_model_file = None
    pretrained_model = load_from_mmar(
        item=item["name"], mmar_dir="./", map_location="cpu", api=True, model_file=default_model_file
    )
    self.assertIsInstance(pretrained_model, torch.nn.Module)
def test_set_full_state(self, device_0, device_1):
    """Copy a full state between two identically-built models, once from a module and once from a dict."""
    set_determinism(0)
    model_one = _TestModelOne(10, 20, 3)
    model_two = _TestModelOne(10, 20, 3)
    model_one.to(device_0)
    model_two.to(device_1)

    # first with a module as destination, then with the returned state dict as destination
    state, changed, unchanged = copy_model_state(model_one, model_two)
    state, changed, unchanged = copy_model_state(state, model_two)

    batch = torch.tensor(np.random.randn(4, 10), device=device_0, dtype=torch.float32)
    out_one = model_one(batch).detach().cpu().numpy()
    # evaluate model_two on the same device as the input batch
    model_two.to(device_0)
    out_two = model_two(batch).detach().cpu().numpy()

    np.testing.assert_allclose(out_one, out_two, atol=1e-3)
    self.assertEqual(len(changed), 4)
    self.assertEqual(len(unchanged), 0)
def test_set_state(self, device_0, device_1):
    """Copy the matching parameters of a ``_TestModelTwo`` into a ``_TestModelOne``."""
    set_determinism(0)
    model_one = _TestModelOne(10, 20, 3)
    model_two = _TestModelTwo(10, 20, 10, 4)
    model_one.to(device_0)
    model_two.to(device_1)

    _, changed, unchanged = copy_model_state(model_one, model_two)

    inputs = torch.tensor(np.random.randn(4, 10), device=device_0, dtype=torch.float32)
    prediction = model_one(inputs)
    result = prediction.detach().cpu().numpy()
    reference = np.array(
        [
            [-0.36076584, -0.03177825, -0.7702266],
            [-0.0526831, -0.15855855, -0.01149344],
            [-0.3760508, -0.22485238, -0.0634037],
            [0.5977675, -0.67991066, 0.1919502],
        ]
    )
    np.testing.assert_allclose(result, reference, atol=1e-3)
    self.assertEqual(len(changed), 2)
    self.assertEqual(len(unchanged), 2)
def __call__(self, engine: Engine) -> None:
    """
    Load the checkpoint stored at ``self.load_path`` into the objects of ``self.load_dict``.

    When ``self.strict_shape`` is False, only ``torch.nn.Module`` items are loaded. For each of them,
    only the checkpoint entries whose key name and data shape match the module are copied. Other
    items are skipped with a warning.

    Args:
        engine: Ignite Engine, it can be a trainer, validator or evaluator.

    Raises:
        ValueError: when the epoch count restored into ``engine.state.epoch`` is larger than the
            ``max_epochs`` the engine was configured with.
    """
    checkpoint = torch.load(self.load_path, map_location=self.map_location)

    # shallow copy: skipping items below must not permanently remove them from the handler's config
    load_dict = dict(self.load_dict)

    if len(load_dict) == 1:
        (name,) = load_dict
        # single object and checkpoint is directly a state_dict
        if name not in checkpoint:
            checkpoint = {name: checkpoint}

    if not self.strict_shape:
        for name, obj in list(load_dict.items()):
            if not isinstance(obj, torch.nn.Module):
                warnings.warn("`strict_shape` is False, load checkpoint for model, skip others in `load_dict`.")
                load_dict.pop(name)
            elif name in checkpoint:
                # skip entries that don't match key name or data shape; the source must be this
                # object's own entry, the top-level checkpoint is keyed by object names
                checkpoint[name] = copy_model_state(obj, checkpoint[name], inplace=False)[0]
            # a missing `name` is left for `Checkpoint.load_objects` to report

    # save current max epochs setting in the engine, don't overwrite it if larger than max_epochs in checkpoint
    prior_max_epochs = engine.state.max_epochs
    Checkpoint.load_objects(to_load=load_dict, checkpoint=checkpoint, strict=self.strict)
    # `max_epochs` may legitimately be None, in which case there is nothing to compare against
    if prior_max_epochs is not None and engine.state.epoch > prior_max_epochs:
        raise ValueError(
            f"Epoch count ({engine.state.epoch}) in checkpoint is larger than "
            f"the `engine.state.max_epochs` ({prior_max_epochs}) of engine. To further train from checkpoint, "
            "construct trainer with `max_epochs` larger than checkpoint's epoch count. "
            "To use checkpoint for inference, no need to load state_dict for the engine."
        )
    engine.state.max_epochs = prior_max_epochs

    self.logger.info(f"Restored all variables from {self.load_path}")
def test_set_exclude_vars(self, device_0, device_1):
    """Copy matching parameters while excluding ``layer.bias`` via ``exclude_vars``."""
    set_determinism(0)
    model_one = _TestModelOne(10, 20, 3)
    model_two = _TestModelTwo(10, 20, 10, 4)
    model_one.to(device_0)
    model_two.to(device_1)

    # test skip layer.bias
    _, changed, unchanged = copy_model_state(model_one, model_two, exclude_vars="layer.bias")

    inputs = torch.tensor(np.random.randn(4, 10), device=device_0, dtype=torch.float32)
    prediction = model_one(inputs)
    result = prediction.detach().cpu().numpy()
    reference = np.array(
        [
            [-0.34172416, 0.0375042, -0.98340976],
            [-0.03364138, -0.08927619, -0.2246768],
            [-0.35700908, -0.15556987, -0.27658707],
            [0.61680925, -0.6106281, -0.02123314],
        ]
    )
    np.testing.assert_allclose(result, reference, atol=1e-3)
    self.assertEqual(len(changed), 1)
    self.assertEqual(len(unchanged), 3)
def load_from_mmar(
    item,
    mmar_dir: Optional[PathLike] = None,
    progress: bool = True,
    version: int = -1,
    map_location=None,
    pretrained=True,
    weights_only=False,
    model_key: str = "model",
    api: bool = True,
    model_file=None,
):
    """
    Download and extract Medical Model Archive (MMAR) model weights from Nvidia Clara Train.

    Args:
        item: the corresponding model item from `MODEL_DESC`. A value that is not a mapping
            (e.g. an int index) is resolved with `get_model_spec`.
        mmar_dir: target directory to store the MMAR, default is mmars subfolder under `torch.hub get_dir()`.
        progress: whether to display a progress bar when downloading the content.
        version: version number of the MMAR. Set it to `-1` to use `item[Keys.VERSION]`.
        map_location: pytorch API parameter for `torch.load` or `torch.jit.load`.
        pretrained: whether to load the pretrained weights after initializing a network module.
        weights_only: whether to load only the weights instead of initializing the network module and assign weights.
        model_key: a key to search in the model file or config file for the model dictionary.
            Currently this function assumes that the model dictionary has
            `{"[name|path]": "test.module", "args": {'kw': 'test'}}`.
        api: whether to query NGC API to get model information. In this mode only the model name
            of `item` is used.
        model_file: the relative path to the model file within an MMAR. Defaults to `models/model.pt`;
            an `item[Keys.MODEL_FILE]` entry takes precedence over it.

    Returns:
        the result of `torch.jit.load` for a `.ts` model file; the object stored under `model_key`
        in the loaded file (or the whole loaded object) when `weights_only` is True; otherwise an
        instance of the configured network class, with the pretrained weights copied into it when
        `pretrained` is True.

    Raises:
        ValueError: when no model config dictionary can be found, when the configured model class
            cannot be imported, or when the model config has neither a "name" nor a "path".

    Examples::

        >>> from monai.apps import load_from_mmar
        >>> unet_model = load_from_mmar("clara_pt_prostate_mri_segmentation_1", mmar_dir=".", map_location="cpu")
        >>> print(unet_model)

    See Also:
        https://docs.nvidia.com/clara/
    """
    if api:
        # the NGC query only needs the model name; other fields of `item` are discarded
        if isinstance(item, int):
            query_name = get_model_spec(item)[Keys.NAME]
        elif isinstance(item, Mapping):
            # use the item's name rather than the string form of the whole mapping
            query_name = item.get(Keys.NAME, f"{item}")
        else:
            query_name = f"{item}"
        item = {Keys.NAME: query_name}
    if not isinstance(item, Mapping):
        item = get_model_spec(item)
    model_dir = download_mmar(item=item, mmar_dir=mmar_dir, progress=progress, version=version, api=api)
    if model_file is None:
        model_file = os.path.join("models", "model.pt")
    model_file = model_dir / item.get(Keys.MODEL_FILE, model_file)
    logger.info(f'\n*** "{item.get(Keys.NAME)}" available at {model_dir}.')

    # loading with `torch.jit.load`
    if model_file.name.endswith(".ts"):
        if not pretrained:
            warnings.warn("Loading a ScriptModule, 'pretrained' option ignored.")
        if weights_only:
            warnings.warn("Loading a ScriptModule, 'weights_only' option ignored.")
        return torch.jit.load(model_file, map_location=map_location)

    # loading with `torch.load`
    model_dict = torch.load(model_file, map_location=map_location)
    if weights_only:
        return model_dict.get(model_key, model_dict)  # model_dict[model_key] or model_dict directly

    # 1. search `model_dict['train_conf']` for model config spec.
    model_config = _get_val(dict(model_dict).get("train_conf", {}), key=model_key, default={})
    if not model_config or not isinstance(model_config, Mapping):
        # 2. search json CONFIG_FILE for model config spec.
        json_path = model_dir / item.get(Keys.CONFIG_FILE, os.path.join("config", "config_train.json"))
        with open(json_path) as f:
            conf_dict = json.load(f)
        conf_dict = dict(conf_dict)
        model_config = _get_val(conf_dict, key=model_key, default={})
    if not model_config:
        # 3. search `model_dict` for model config spec.
        model_config = _get_val(dict(model_dict), key=model_key, default={})

    if not (model_config and isinstance(model_config, Mapping)):
        raise ValueError(
            f"Could not load model config dictionary from config: {item.get(Keys.CONFIG_FILE)}, "
            f"or from model file: {item.get(Keys.MODEL_FILE)}."
        )

    # parse `model_config` for model class and model parameters
    if model_config.get("name"):  # model config section is a "name"
        model_name = model_config["name"]
        model_cls = monai_nets.__dict__[model_name]
    elif model_config.get("path"):  # model config section is a "path"
        # https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html
        model_module, model_name = model_config.get("path", ".").rsplit(".", 1)
        model_cls, has_cls = optional_import(module=model_module, name=model_name)
        if not has_cls:
            raise ValueError(
                f"Could not load MMAR model config {model_config.get('path', '')}, "
                f"Please make sure MMAR's sub-folders in '{model_dir}' is on the PYTHONPATH. "
                "See also: https://docs.nvidia.com/clara/clara-train-sdk/pt/byom.html"
            )
    else:
        raise ValueError(f"Could not load model config {model_config}.")

    logger.info(f"*** Model: {model_cls}")
    model_kwargs = model_config.get("args", None)
    if model_kwargs:
        model_inst = model_cls(**model_kwargs)
        logger.info(f"*** Model params: {model_kwargs}")
    else:
        model_inst = model_cls()
    if pretrained:
        _, changed, unchanged = copy_model_state(model_inst, model_dict.get(model_key, model_dict), inplace=True)
        # warn unless at least one variable changed and none stayed unchanged
        if not (changed and not unchanged):  # not all model_inst variables are changed
            logger.warning(f"*** Loading model state -- unchanged: {len(unchanged)}, changed: {len(changed)}.")
    logger.info("\n---")
    doc_url = item.get(Keys.DOC) or _get_ngc_doc_url(item[Keys.NAME], model_prefix="nvidia:med:")
    logger.info(f"For more information, please visit {doc_url}\n")
    return model_inst