Exemple #1
0
    def load(self,
             model_dir: str,
             network_map: Dict[str, str] = None,
             version: int = -1):
        """
        Restore every restorable module of this object from saved files.

        Each name listed in ``self._is_restorable`` is looked up in
        ``network_map`` to find the saved name it should be loaded from.
        Names missing from the map fall back to the module name itself,
        and a warning is logged for each of them.

        An example of network map::

            {"restorable_model_1": "file_name_1",
             "restorable_model_2": "file_name_2"}

        Get keys by calling ``<Class name>.get_restorable()``

        Args:
            model_dir: Save directory.
            network_map: Key is module name, value is saved name.
                ``None`` is treated as an empty map.
            version: Version number of the save to be loaded.
        """
        if network_map is None:
            network_map = {}

        # Maps saved name -> module object, the form handed to
        # ``prep_load_model`` below.
        modules_by_saved_name = {}
        for module_name in self._is_restorable:
            if module_name not in network_map:
                # No explicit saved name given: warn, then use the
                # module's own name as the saved name.
                default_logger.warning(
                    'Load path for module "{}" is not specified, '
                    "module name is used.".format(module_name))
                saved_name = module_name
            else:
                saved_name = network_map[module_name]
            modules_by_saved_name[saved_name] = getattr(self, module_name)

        prep_load_model(model_dir, modules_by_saved_name, version)
Exemple #2
0
def test_prep_load_model(tmpdir):
    """
    Exercise ``prep_load_model`` on a populated and an empty directory.

    Two saves of the same linear layer are written to the first
    directory: ``model_0.pt`` with all-zero weights and ``model_100.pt``
    with all-one weights. The weight values therefore identify which
    save was actually loaded. The second directory is left empty.

    Args:
        tmpdir: pytest temporary directory fixture.
    """
    tmp_dir = str(tmpdir.make_numbered_dir())
    tmp_dir2 = str(tmpdir.make_numbered_dir())

    # create example model directory
    # version 0 holds zeros, version 100 holds ones
    with t.no_grad():
        model = t.nn.Linear(100, 100, bias=False)
        model.weight.fill_(0)
        t.save(model, join(tmp_dir, "model_0.pt"))
        model.weight.fill_(1)
        t.save(model, join(tmp_dir, "model_100.pt"))

    with pytest.raises(RuntimeError, match="Model directory doesn't exist"):
        prep_load_model(join(tmp_dir, "not_exist_dir"), {"model": model})

    # load a specific version
    # (weights are all ones here, so zeros can only come from the load)
    prep_load_model(tmp_dir, {"model": model}, version=0)
    assert t.all(model.weight == 0)

    # load a non-exist version in a directory with valid models
    # will load version 100
    prep_load_model(tmp_dir, {"model": model}, version=50)
    assert t.all(model.weight == 1)

    # load the newest version
    # The previous load left the weights at all ones, so reset them to
    # zeros first; otherwise the assertion below would also pass if this
    # call loaded nothing at all.
    with t.no_grad():
        model.weight.fill_(0)
    prep_load_model(tmp_dir, {"model": model})
    assert t.all(model.weight == 1)

    # load a non-exist version in a directory with invalid models
    # eg: cannot find the same version for all models in the model map
    with pytest.raises(RuntimeError, match="Cannot find a valid version"):
        prep_load_model(tmp_dir2, {"model": model})
    # the same call with quiet=True is expected to return without raising
    prep_load_model(tmp_dir2, {"model": model}, quiet=True)