Example #1
    def test_regex_matches_are_initialized_correctly(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                self.conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):  # pylint: disable=arguments-differ
                pass

        # Make sure we handle regexes properly
        json_params = """{"initializer": [
        ["conv", {"type": "constant", "val": 5}],
        ["funky_na.*bi", {"type": "constant", "val": 7}]
        ]}
        """
        params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
        initializers = InitializerApplicator.from_params(params['initializer'])
        model = Net()
        initializers(model)

        for parameter in model.conv.parameters():
            assert torch.equal(parameter.data,
                               torch.ones(parameter.size()) * 5)

        parameter = model.linear_1_with_funky_name.bias
        assert torch.equal(parameter.data, torch.ones(parameter.size()) * 7)
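The test above relies on regex-based matching of parameter names: each parameter's dotted name (for example "linear_1_with_funky_name.bias") is searched against the configured regexes, and the first matching entry's initializer is applied. A rough sketch of that matching rule, using plain re.search and the same regexes (an illustration of the behaviour being tested, not AllenNLP's actual implementation):

import re

# Parameter names as PyTorch reports them via model.named_parameters().
named_parameters = ["conv.weight", "conv.bias",
                    "linear_1_with_funky_name.weight",
                    "linear_1_with_funky_name.bias",
                    "linear_2.weight"]
rules = [("conv", 5.0), ("funky_na.*bi", 7.0)]

for name in named_parameters:
    for regex, val in rules:
        if re.search(regex, name):
            print(f"{name} -> constant {val}")
            break
    else:
        print(f"{name} -> left at its default initialization")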
Example #2
    def test_regex_match_prevention_prevents_and_overrides(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.linear_1 = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                # typical actual usage: modules loaded from allennlp.model.load(..)
                self.linear_3_transfer = torch.nn.Linear(5, 10)
                self.linear_4_transfer = torch.nn.Linear(10, 5)
                self.pretrained_conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):  # pylint: disable=arguments-differ
                pass

        json_params = """{"initializer": [
        [".*linear.*", {"type": "constant", "val": 10}],
        [".*conv.*", {"type": "constant", "val": 10}],
        [".*_transfer.*", "prevent"],
        [".*pretrained.*",{"type": "prevent"}]
        ]}
        """
        params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
        initializers = InitializerApplicator.from_params(params['initializer'])
        model = Net()
        initializers(model)

        for module in [model.linear_1, model.linear_2]:
            for parameter in module.parameters():
                assert torch.equal(parameter.data,
                                   torch.ones(parameter.size()) * 10)

        transferred_modules = [
            model.linear_3_transfer, model.linear_4_transfer,
            model.pretrained_conv
        ]

        for module in transferred_modules:
            for parameter in module.parameters():
                assert not torch.equal(parameter.data,
                                       torch.ones(parameter.size()) * 10)
Example #3
    @classmethod
    def from_params(cls, vocab: Vocabulary, params: Params) -> 'BiattentiveClassificationNetwork':  # type: ignore
        # pylint: disable=arguments-differ
        embedder_params = params.pop("text_field_embedder")
        text_field_embedder = TextFieldEmbedder.from_params(vocab=vocab, params=embedder_params)
        embedding_dropout = params.pop("embedding_dropout")
        pre_encode_feedforward = FeedForward.from_params(params.pop("pre_encode_feedforward"))
        encoder = Seq2SeqEncoder.from_params(params.pop("encoder"))
        integrator = Seq2SeqEncoder.from_params(params.pop("integrator"))
        integrator_dropout = params.pop("integrator_dropout")

        output_layer_params = params.pop("output_layer")
        if "activations" in output_layer_params:
            output_layer = FeedForward.from_params(output_layer_params)
        else:
            output_layer = Maxout.from_params(output_layer_params)

        elmo = params.pop("elmo", None)
        if elmo is not None:
            elmo = Elmo.from_params(elmo)
        use_input_elmo = params.pop_bool("use_input_elmo", False)
        use_integrator_output_elmo = params.pop_bool("use_integrator_output_elmo", False)

        initializer = InitializerApplicator.from_params(params.pop('initializer', []))
        regularizer = RegularizerApplicator.from_params(params.pop('regularizer', []))
        params.assert_empty(cls.__name__)

        return cls(vocab=vocab,
                   text_field_embedder=text_field_embedder,
                   embedding_dropout=embedding_dropout,
                   pre_encode_feedforward=pre_encode_feedforward,
                   encoder=encoder,
                   integrator=integrator,
                   integrator_dropout=integrator_dropout,
                   output_layer=output_layer,
                   elmo=elmo,
                   use_input_elmo=use_input_elmo,
                   use_integrator_output_elmo=use_integrator_output_elmo,
                   initializer=initializer,
                   regularizer=regularizer)
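For context, a minimal usage sketch of how a from_params factory like this is usually driven from a configuration file. The file name and the "model" key below are hypothetical, and the import path for BiattentiveClassificationNetwork is assumed to be allennlp.models:

from allennlp.common import Params
from allennlp.data import Vocabulary
from allennlp.models import BiattentiveClassificationNetwork  # assumed import path

params = Params.from_file("experiment.jsonnet")  # hypothetical config file
vocab = Vocabulary()  # in practice, built from the training data
model = BiattentiveClassificationNetwork.from_params(vocab=vocab,
                                                     params=params.pop("model"))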