Example #1
    def test_remove_keys_from_params(self):
        filename = self.FIXTURES_ROOT / "simple_tagger" / "experiment.json"
        params = Params.from_file(filename)

        assert params["data_loader"]["batch_sampler"]["type"] == "bucket"
        assert params["data_loader"]["batch_sampler"]["batch_size"] == 80

        remove_keys_from_params(params, keys=["batch_size"])
        assert "batch_size" not in params["data_loader"]["batch_sampler"]

        remove_keys_from_params(params, keys=["type", "batch_size"])
        assert "type" not in params["data_loader"]["batch_sampler"]

        remove_keys_from_params(params, keys=["data_loader"])
        assert "data_loader" not in params
Example #2
def remove_pretrained_embedding_params(params: Params):
    """This function only exists for backwards compatibility.
    Please use `remove_weights_related_keys_from_params()` instead."""
    remove_keys_from_params(params, ["pretrained_file"])
Example #3
def remove_weights_related_keys_from_params(
    params: Params, keys: List[str] = ["pretrained_file", "initializer"]
):
    remove_keys_from_params(params, keys)
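Because the default keys are "pretrained_file" and "initializer", a single call to this wrapper scrubs both from every nested block of a config. A small sketch (the nested config below is made up for illustration, and the import path assumes the wrapper is defined in allennlp.models.model, as in recent AllenNLP versions):

from allennlp.common.params import Params
# Import path is an assumption; adjust for your AllenNLP version.
from allennlp.models.model import remove_weights_related_keys_from_params

config = Params(
    {
        "model": {
            "initializer": {"regexes": []},
            "text_field_embedder": {
                "token_embedders": {
                    "tokens": {"type": "embedding", "pretrained_file": "glove.txt"}
                }
            },
        }
    }
)

remove_weights_related_keys_from_params(config)
assert "initializer" not in config["model"]
tokens = config["model"]["text_field_embedder"]["token_embedders"]["tokens"]
assert "pretrained_file" not in tokens

Note the mutable default value for keys: the same list object is shared across calls, so callers should pass their own list rather than mutate the default in place.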
Example #4
    @classmethod
    def _load(
        cls,
        config: Params,
        serialization_dir: Union[str, PathLike],
        weights_file: Optional[Union[str, PathLike]] = None,
        cuda_device: int = -1,
    ) -> "Model":
        """
        Instantiates an already-trained model, based on the experiment
        configuration and some optional overrides.
        """
        weights_file = weights_file or os.path.join(serialization_dir, _DEFAULT_WEIGHTS)

        # Load vocabulary from file
        vocab_dir = os.path.join(serialization_dir, "vocabulary")
        # If the config specifies a vocabulary subclass, we need to use it.
        vocab_params = config.get("vocabulary", Params({}))
        vocab_choice = vocab_params.pop_choice("type", Vocabulary.list_available(), True)
        vocab_class, _ = Vocabulary.resolve_class_name(vocab_choice)
        vocab = vocab_class.from_files(
            vocab_dir, vocab_params.get("padding_token"), vocab_params.get("oov_token")
        )

        model_params = config.get("model")

        # The experiment config tells us how to _train_ a model, including where to get pre-trained
        # embeddings/weights from. We're now _loading_ the model, so those weights will already be
        # stored in our model. We don't need any pretrained weight file or initializers anymore,
        # and we don't want the code to look for it, so we remove it from the parameters here.
        remove_keys_from_params(model_params)
        model = Model.from_params(vocab=vocab, params=model_params)

        # Force model to cpu or gpu, as appropriate, to make sure that the embeddings are
        # in sync with the weights
        if cuda_device >= 0:
            model.cuda(cuda_device)
        else:
            model.cpu()

        # If vocab+embedding extension was done, the model initialized from from_params
        # and the one defined by the state dict in weights_file might not have the same
        # embedding shapes. E.g. when the model's embedder module was transferred along
        # with vocab extension, the initialized embedding weight shape would be smaller
        # than the one in the state_dict. So calling the model's embedding extension is
        # required before load_state_dict. If the vocab and model embeddings are in sync,
        # the following is just a no-op.
        model.extend_embedder_vocab()

        # Load state dict. We pass `strict=False` so PyTorch doesn't raise a RuntimeError
        # if the state dict is missing keys because we handle this case below.
        model_state = torch.load(weights_file, map_location=util.device_mapping(cuda_device))
        missing_keys, unexpected_keys = model.load_state_dict(model_state, strict=False)

        # Modules might define a class variable called `authorized_missing_keys`,
        # a list of regex patterns, that tells us to ignore missing keys that match
        # any of the patterns.
        # We sometimes need this in order to load older models with newer versions of AllenNLP.

        def filter_out_authorized_missing_keys(module, prefix=""):
            nonlocal missing_keys
            for pat in getattr(module.__class__, "authorized_missing_keys", None) or []:
                missing_keys = [
                    k
                    for k in missing_keys
                    if k.startswith(prefix) and re.search(pat[len(prefix) :], k) is None
                ]
            for name, child in module._modules.items():
                if child is not None:
                    filter_out_authorized_missing_keys(child, prefix + name + ".")

        filter_out_authorized_missing_keys(model)

        if unexpected_keys or missing_keys:
            raise RuntimeError(
                f"Error loading state dict for {model.__class__.__name__}\n\t"
                f"Missing keys: {missing_keys}\n\t"
                f"Unexpected keys: {unexpected_keys}"
            )

        return model
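_load() is a private classmethod; it is normally reached through Model.load(), which reads the model type from the config and dispatches to the matching subclass's _load() (load_archive(), which unpacks a model.tar.gz archive, goes through the same path). A usage sketch, assuming a standard serialization directory produced by an AllenNLP training run (the path below is hypothetical):

import os

from allennlp.common.params import Params
from allennlp.models.model import Model

serialization_dir = "/tmp/my_model"  # hypothetical training output directory
config = Params.from_file(os.path.join(serialization_dir, "config.json"))

# Model.load() resolves the concrete Model subclass from config["model"]["type"]
# and delegates to its _load() classmethod, which runs the logic shown above.
model = Model.load(config, serialization_dir=serialization_dir, cuda_device=-1)
model.eval()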