def load(self, model_dir: str, network_map: Dict[str, str] = None, version: int = -1):
    """
    Restore every restorable module of this object from ``model_dir``.

    ``network_map`` translates a module (attribute) name to the name it was
    saved under, for example::

        {"restorable_model_1": "file_name_1",
         "restorable_model_2": "file_name_2"}

    The valid keys are returned by ``<Class name>.get_restorable()``.
    Modules missing from ``network_map`` fall back to their own attribute
    name as the save name, and a warning is logged for each of them.

    Args:
        model_dir: Directory containing the saved models.
        network_map: Mapping from module name to saved name. ``None`` is
            treated as an empty mapping.
        version: Version number of the save to load; forwarded unchanged
            to ``prep_load_model`` (presumably ``-1`` selects the newest
            version — confirm in ``prep_load_model``).
    """
    if network_map is None:
        network_map = {}

    # saved name -> live module object that should receive the weights
    restore_map = {}
    for module_name in self._is_restorable:
        if module_name in network_map:
            save_name = network_map[module_name]
        else:
            default_logger.warning(
                'Load path for module "{}" is not specified, '
                "module name is used.".format(module_name)
            )
            save_name = module_name
        restore_map[save_name] = getattr(self, module_name)

    prep_load_model(model_dir, restore_map, version)
def test_prep_load_model(tmpdir):
    """Exercise ``prep_load_model`` version selection and error handling.

    ``tmp_dir`` holds two saved versions (0 and 100) of the same module.
    ``tmp_dir2`` is left empty, so no usable version exists there.
    """
    tmp_dir = str(tmpdir.make_numbered_dir())
    tmp_dir2 = str(tmpdir.make_numbered_dir())

    # create example model directory: version 0 is saved with all-zero
    # weights, version 100 with all-one weights
    with t.no_grad():
        model = t.nn.Linear(100, 100, bias=False)
        model.weight.fill_(0)
        t.save(model, join(tmp_dir, "model_0.pt"))
        model.weight.fill_(1)
        t.save(model, join(tmp_dir, "model_100.pt"))

    with pytest.raises(RuntimeError, match="Model directory doesn't exist"):
        prep_load_model(join(tmp_dir, "not_exist_dir"), {"model": model})

    # load a specific version; weights are currently 1, so reaching 0
    # proves the load actually happened
    prep_load_model(tmp_dir, {"model": model}, version=0)
    assert t.all(model.weight == 0)

    # load a non-existent version in a directory with valid models;
    # version 100 is loaded instead
    prep_load_model(tmp_dir, {"model": model}, version=50)
    assert t.all(model.weight == 1)

    # load the newest version; reset the weights first so the assertion
    # cannot pass merely because the previous load already left them at 1
    with t.no_grad():
        model.weight.fill_(0)
    prep_load_model(tmp_dir, {"model": model})
    assert t.all(model.weight == 1)

    # tmp_dir2 contains no saved models, so no version exists for every
    # model in the model map
    with pytest.raises(RuntimeError, match="Cannot find a valid version"):
        prep_load_model(tmp_dir2, {"model": model})
    # with quiet=True the same situation is expected not to raise
    prep_load_model(tmp_dir2, {"model": model}, quiet=True)