Example #1
    def test_from_params(self):
        # Save a vocab to check we can load it from_params.
        vocab_dir = self.TEST_DIR / 'vocab_save'
        vocab = Vocabulary(non_padded_namespaces=["a", "c"])
        vocab.add_token_to_namespace("a0", namespace="a")  # non-padded, should start at 0
        vocab.add_token_to_namespace("a1", namespace="a")
        vocab.add_token_to_namespace("a2", namespace="a")
        vocab.add_token_to_namespace("b2", namespace="b")  # padded, should start at 2
        vocab.add_token_to_namespace("b3", namespace="b")
        vocab.save_to_files(vocab_dir)

        params = Params({"directory_path": vocab_dir})
        vocab2 = Vocabulary.from_params(params)
        assert vocab.get_index_to_token_vocabulary("a") == vocab2.get_index_to_token_vocabulary("a")
        assert vocab.get_index_to_token_vocabulary("b") == vocab2.get_index_to_token_vocabulary("b")

        # Test case where we build a vocab from a dataset.
        vocab2 = Vocabulary.from_params(Params({}), self.dataset)
        assert vocab2.get_index_to_token_vocabulary("tokens") == {0: '@@PADDING@@',
                                                                  1: '@@UNKNOWN@@',
                                                                  2: 'a', 3: 'c', 4: 'b'}
        # Test from_params raises when we have neither a dataset nor a vocab directory.
        with pytest.raises(ConfigurationError):
            _ = Vocabulary.from_params(Params({}))

        # Test from_params raises when there are any other dict keys
        # present apart from 'directory_path' and we aren't calling from_dataset.
        with pytest.raises(ConfigurationError):
            _ = Vocabulary.from_params(Params({"directory_path": vocab_dir, "min_count": {'tokens': 2}}))
Example #2
 def test_no_metric_wrapper_can_support_none_for_metrics(self):
     model = torch.nn.Sequential(torch.nn.Linear(10, 10))
     lrs = LearningRateScheduler.from_params(
         Optimizer.from_params(model.named_parameters(),
                               Params({"type": "adam"})),
         Params({
             "type": "step",
             "step_size": 1
         }))
     lrs.step(None, None)
Example #3
    def test_reduce_on_plateau_error_throw_when_no_metrics_exist(self):
        model = torch.nn.Sequential(torch.nn.Linear(10, 10))
        with self.assertRaises(ConfigurationError) as context:
            LearningRateScheduler.from_params(
                Optimizer.from_params(model.named_parameters(),
                                      Params({"type": "adam"})),
                Params({"type": "reduce_on_plateau"})).step(None, None)

        self.assertTrue(
            'The reduce_on_plateau learning rate scheduler requires a validation metric'
            in str(context.exception))
Example #4
 def test_noam_learning_rate_schedule_does_not_crash(self):
     model = torch.nn.Sequential(torch.nn.Linear(10, 10))
     lrs = LearningRateScheduler.from_params(
         Optimizer.from_params(model.named_parameters(),
                               Params({"type": "adam"})),
         Params({
             "type": "noam",
             "model_size": 10,
             "warmup_steps": 2000
         }))
     lrs.step(None)
     lrs.step_batch(None)
Example #5
    def test_regex_matches_are_initialized_correctly(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.linear_1_with_funky_name = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                self.conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):  # pylint: disable=arguments-differ
                pass

        # Make sure we handle regexes properly
        json_params = """{"initializer": [
        ["conv", {"type": "constant", "val": 5}],
        ["funky_na.*bi", {"type": "constant", "val": 7}]
        ]}
        """
        params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
        initializers = InitializerApplicator.from_params(params['initializer'])
        model = Net()
        initializers(model)

        for parameter in model.conv.parameters():
            assert torch.equal(parameter.data,
                               torch.ones(parameter.size()) * 5)

        parameter = model.linear_1_with_funky_name.bias
        assert torch.equal(parameter.data, torch.ones(parameter.size()) * 7)
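A note on the second regex above: the initializer regexes are matched against full parameter names with re.search, which is why "funky_na.*bi" picks out the bias of linear_1_with_funky_name but not its weight. A quick standalone check of that matching, using plain re and independent of AllenNLP:

import re

pattern = "funky_na.*bi"
assert re.search(pattern, "linear_1_with_funky_name.bias")        # the bias matches
assert not re.search(pattern, "linear_1_with_funky_name.weight")  # the weight does not
assert not re.search(pattern, "conv.bias")                        # neither do other biases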
Example #6
 def setUp(self):
     super(TestTrainer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params({
         "text_field_embedder": {
             "tokens": {
                 "type": "embedding",
                 "embedding_dim": 5
             }
         },
         "encoder": {
             "type": "lstm",
             "input_size": 5,
             "hidden_size": 7,
             "num_layers": 2
         }
     })
     self.model = SimpleTagger.from_params(vocab=self.vocab,
                                           params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
     self.iterator.index_with(vocab)
Example #7
 def test_can_optimise_model_with_dense_and_sparse_params(self):
     optimizer_params = Params({"type": "dense_sparse_adam"})
     parameters = [[n, p] for n, p in self.model.named_parameters()
                   if p.requires_grad]
     optimizer = Optimizer.from_params(parameters, optimizer_params)
     iterator = BasicIterator(2)
     iterator.index_with(self.vocab)
     Trainer(self.model, optimizer, iterator, self.instances).train()
Example #8
 def test_optimizer_basic(self):
     optimizer_params = Params({"type": "sgd", "lr": 1})
     parameters = [[n, p] for n, p in self.model.named_parameters()
                   if p.requires_grad]
     optimizer = Optimizer.from_params(parameters, optimizer_params)
     param_groups = optimizer.param_groups
     assert len(param_groups) == 1
     assert param_groups[0]['lr'] == 1
Example #9
    def test_from_params(self):
        optim = self._get_optimizer()
        sched = LearningRateScheduler.from_params(
            optim, Params({
                "type": "cosine",
                "t_max": 5
            })).lr_scheduler

        assert sched.t_max == 5
        assert sched._initialized is True

        # Learning rate should be unchanged after initializing the scheduler.
        assert optim.param_groups[0]["lr"] == 1.0

        with self.assertRaises(TypeError):
            # t_max is required.
            LearningRateScheduler.from_params(optim, Params({"type":
                                                             "cosine"}))
Example #10
 def test_as_ordered_dict(self):
     # keyD > keyC > keyE; keyDA > keyDB; Next all other keys alphabetically
     preference_orders = [["keyD", "keyC", "keyE"], ["keyDA", "keyDB"]]
     params = Params({"keyC": "valC", "keyB": "valB", "keyA": "valA", "keyE": "valE",
                      "keyD": {"keyDB": "valDB", "keyDA": "valDA"}})
     ordered_params_dict = params.as_ordered_dict(preference_orders)
     expected_ordered_params_dict = OrderedDict({'keyD': {'keyDA': 'valDA', 'keyDB': 'valDB'},
                                                 'keyC': 'valC', 'keyE': 'valE',
                                                 'keyA': 'valA', 'keyB': 'valB'})
     assert json.dumps(ordered_params_dict) == json.dumps(expected_ordered_params_dict)
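The ordering rule exercised here (keys named in a preference order come first, in that order; remaining keys follow alphabetically; the rule applies recursively to nested dicts) can be sketched without Params at all. This is an illustrative re-implementation of that rule, not AllenNLP's as_ordered_dict:

from collections import OrderedDict

def order_keys(nested, preference_orders):
    # Keys listed in a preference order sort first (in that order); everything else alphabetically.
    def rank(key):
        positions = [order.index(key) if key in order else len(order)
                     for order in preference_orders]
        return tuple(positions + [key])  # alphabetical tie-break

    result = OrderedDict()
    for key in sorted(nested, key=rank):
        value = nested[key]
        result[key] = order_keys(value, preference_orders) if isinstance(value, dict) else value
    return result

print(order_keys({"keyC": "valC", "keyB": "valB", "keyA": "valA", "keyE": "valE",
                  "keyD": {"keyDB": "valDB", "keyDA": "valDA"}},
                 [["keyD", "keyC", "keyE"], ["keyDA", "keyDB"]]))
# OrderedDict([('keyD', OrderedDict([('keyDA', 'valDA'), ('keyDB', 'valDB')])),
#              ('keyC', 'valC'), ('keyE', 'valE'), ('keyA', 'valA'), ('keyB', 'valB')])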
Example #11
 def test_span_f1_can_build_from_params(self):
     params = Params({
         "type": "span_f1",
         "tag_namespace": "tags",
         "ignore_classes": ["V"]
     })
     metric = Metric.from_params(params=params, vocabulary=self.vocab)
     assert metric._ignore_classes == ["V"]
     assert metric._label_vocabulary == self.vocab.get_index_to_token_vocabulary(
         "tags")
Example #12
 def test_bidirectional_endpoint_span_extractor_can_build_from_params(self):
     params = Params({
         "type": "bidirectional_endpoint",
         "input_dim": 4,
         "num_width_embeddings": 5,
         "span_width_embedding_dim": 3
     })
     extractor = SpanExtractor.from_params(params)
     assert isinstance(extractor, BidirectionalEndpointSpanExtractor)
     assert extractor.get_output_dim() == 2 + 2 + 3
Example #13
    def test_as_flat_dict(self):
        params = Params({
                'a': 10,
                'b': {
                        'c': 20,
                        'd': 'stuff'
                }
        }).as_flat_dict()

        assert params == {'a': 10, 'b.c': 20, 'b.d': 'stuff'}
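The dotted-key flattening asserted above is simple enough to sketch on a plain dict; this is an illustration of the expected behaviour, not Params.as_flat_dict itself:

def flatten(nested, prefix=""):
    # Nested keys are joined with '.' to form a single flat mapping.
    flat = {}
    for key, value in nested.items():
        full_key = prefix + key
        if isinstance(value, dict):
            flat.update(flatten(value, prefix=full_key + "."))
        else:
            flat[full_key] = value
    return flat

assert flatten({'a': 10, 'b': {'c': 20, 'd': 'stuff'}}) == {'a': 10, 'b.c': 20, 'b.d': 'stuff'}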
Example #14
 def test_endpoint_span_extractor_can_build_from_params(self):
     params = Params({
         "type": "endpoint",
         "input_dim": 7,
         "num_width_embeddings": 5,
         "span_width_embedding_dim": 3
     })
     extractor = SpanExtractor.from_params(params)
     assert isinstance(extractor, EndpointSpanExtractor)
     assert extractor.get_output_dim() == 17  # 2 * input_dim + span_width_embedding_dim
Example #15
    def test_multi_head_self_attention_can_build_from_params(self):
        params = Params({
            "num_heads": 3,
            "input_dim": 2,
            "attention_dim": 3,
            "values_dim": 6
        })

        encoder = MultiHeadSelfAttention.from_params(params)
        assert isinstance(encoder, MultiHeadSelfAttention)
        assert encoder.get_input_dim() == 2
        assert encoder.get_output_dim() == 2
Example #16
    def test_stacked_bidirectional_lstm_can_build_from_params(self):
        params = Params({
            "type": "stacked_bidirectional_lstm",
            "input_size": 5,
            "hidden_size": 9,
            "num_layers": 3
        })
        encoder = Seq2SeqEncoder.from_params(params)

        assert encoder.get_input_dim() == 5
        assert encoder.get_output_dim() == 18
        assert encoder.is_bidirectional()
Example #17
    def test_max_vocab_size_dict(self):
        params = Params({
                "max_vocab_size": {
                        "tokens": 1,
                        "characters": 20
                }
        })

        vocab = Vocabulary.from_params(params=params, instances=self.dataset)
        words = vocab.get_index_to_token_vocabulary().values()
        # Additional 2 tokens are '@@PADDING@@' and '@@UNKNOWN@@' by default
        assert len(words) == 3
Example #18
 def test_to_file(self):
     # Test to_file works with or without preference orders
     params_dict = {"keyA": "valA", "keyB": "valB"}
     expected_ordered_params_dict = OrderedDict({"keyB": "valB", "keyA": "valA"})
     params = Params(params_dict)
     file_path = self.TEST_DIR / 'config.jsonnet'
     # check with preference orders
     params.to_file(file_path, [["keyB", "keyA"]])
     with open(file_path, "r") as handle:
         ordered_params_dict = OrderedDict(json.load(handle))
     assert json.dumps(expected_ordered_params_dict) == json.dumps(ordered_params_dict)
     # check without preference orders doesn't give error
     params.to_file(file_path)
Example #19
    def test_from_params_extend_config(self):

        vocab_dir = self.TEST_DIR / 'vocab_save'
        original_vocab = Vocabulary(non_padded_namespaces=["tokens"])
        original_vocab.add_token_to_namespace("a", namespace="tokens")
        original_vocab.save_to_files(vocab_dir)

        text_field = TextField([Token(t) for t in ["a", "b"]],
                               {"tokens": SingleIdTokenIndexer("tokens")})
        instances = Batch([Instance({"text": text_field})])

        # If you ask to extend vocab from `directory_path`, instances must be passed
        # in Vocabulary constructor, or else there is nothing to extend to.
        params = Params({"directory_path": vocab_dir, "extend": True})
        with pytest.raises(ConfigurationError):
            _ = Vocabulary.from_params(params)

        # If you ask to extend vocab, `directory_path` key must be present in params,
        # or else there is nothing to extend from.
        params = Params({"extend": True})
        with pytest.raises(ConfigurationError):
            _ = Vocabulary.from_params(params, instances)
Example #20
    def test_trainer_can_run_with_lr_scheduler(self):

        lr_params = Params({"type": "reduce_on_plateau"})
        lr_scheduler = LearningRateScheduler.from_params(
            self.optimizer, lr_params)
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          learning_rate_scheduler=lr_scheduler,
                          validation_metric="-loss",
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2)
        trainer.train()
Example #21
    def test_max_vocab_size_partial_dict(self):
        indexers = {"tokens": SingleIdTokenIndexer(), "token_characters": TokenCharactersIndexer()}
        instance = Instance({
                'text': TextField([Token(w) for w in 'Abc def ghi jkl mno pqr stu vwx yz'.split(' ')], indexers)
        })
        dataset = Batch([instance])
        params = Params({
                "max_vocab_size": {
                        "tokens": 1
                }
        })

        vocab = Vocabulary.from_params(params=params, instances=dataset)
        assert len(vocab.get_index_to_token_vocabulary("tokens").values()) == 3  # 1 token + padding + unknown
        assert len(vocab.get_index_to_token_vocabulary("token_characters").values()) == 28  # 26 characters + padding + unknown
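Where the expected sizes come from: max_vocab_size caps the "tokens" namespace at 1 word, and padding and unknown tokens are added to padded namespaces, giving 1 + 2 = 3; the "token_characters" namespace is uncapped and the sentence contains 26 distinct non-space characters, giving 26 + 2 = 28. A quick check of that character count:

text = 'Abc def ghi jkl mno pqr stu vwx yz'
distinct_chars = {char for char in text if char != ' '}
print(len(distinct_chars))  # 26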
Example #22
    def test_registrability(self):

        @Vocabulary.register('my-vocabulary')
        class MyVocabulary:
            @classmethod
            def from_params(cls, params, instances=None):
                # pylint: disable=unused-argument
                return MyVocabulary()


        params = Params({'type': 'my-vocabulary'})

        instance = Instance(fields={})

        vocab = Vocabulary.from_params(params=params, instances=[instance])

        assert isinstance(vocab, MyVocabulary)
Example #23
    def test_invalid_vocab_extension(self):
        vocab_dir = self.TEST_DIR / 'vocab_save'
        original_vocab = Vocabulary(non_padded_namespaces=["tokens1"])
        original_vocab.add_token_to_namespace("a", namespace="tokens1")
        original_vocab.add_token_to_namespace("b", namespace="tokens1")
        original_vocab.add_token_to_namespace("p", namespace="tokens2")
        original_vocab.save_to_files(vocab_dir)
        text_field1 = TextField([Token(t) for t in ["a", "c"]],
                                {"tokens1": SingleIdTokenIndexer("tokens1")})
        text_field2 = TextField([Token(t) for t in ["p", "q", "r"]],
                                {"tokens2": SingleIdTokenIndexer("tokens2")})
        instances = Batch([Instance({"text1": text_field1, "text2": text_field2})])

        # Following 3 should give an error: tokens1 is non-padded in original_vocab
        # but would be padded under the extension params
        params = Params({"directory_path": vocab_dir, "extend": True,
                         "non_padded_namespaces": []})
        with pytest.raises(ConfigurationError):
            _ = Vocabulary.from_params(params, instances)
        with pytest.raises(ConfigurationError):
            extended_vocab = copy.copy(original_vocab)
            params = Params({"non_padded_namespaces": []})
            extended_vocab.extend_from_instances(params, instances)
        with pytest.raises(ConfigurationError):
            extended_vocab = copy.copy(original_vocab)
            extended_vocab._extend(non_padded_namespaces=[],
                                   tokens_to_add={"tokens1": ["a"], "tokens2": ["p"]})

        # Following 3 should not give an error: overlapping namespaces have the same padding setting
        params = Params({"directory_path": vocab_dir, "extend": True,
                         "non_padded_namespaces": ["tokens1"]})
        Vocabulary.from_params(params, instances)
        extended_vocab = copy.copy(original_vocab)
        params = Params({"non_padded_namespaces": ["tokens1"]})
        extended_vocab.extend_from_instances(params, instances)
        extended_vocab = copy.copy(original_vocab)
        extended_vocab._extend(non_padded_namespaces=["tokens1"],
                               tokens_to_add={"tokens1": ["a"], "tokens2": ["p"]})

        # Following 3 should give an error: tokens2 is padded in original_vocab
        # but would be non-padded under the extension params
        params = Params({"directory_path": vocab_dir, "extend": True,
                         "non_padded_namespaces": ["tokens1", "tokens2"]})
        with pytest.raises(ConfigurationError):
            _ = Vocabulary.from_params(params, instances)
        with pytest.raises(ConfigurationError):
            extended_vocab = copy.copy(original_vocab)
            params = Params({"non_padded_namespaces": ["tokens1", "tokens2"]})
            extended_vocab.extend_from_instances(params, instances)
        with pytest.raises(ConfigurationError):
            extended_vocab = copy.copy(original_vocab)
            extended_vocab._extend(non_padded_namespaces=["tokens1", "tokens2"],
                                   tokens_to_add={"tokens1": ["a"], "tokens2": ["p"]})
Example #24
 def setUp(self):
     super(TestOptimizer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.model_params = Params({
         "text_field_embedder": {
             "tokens": {
                 "type": "embedding",
                 "embedding_dim": 5
             }
         },
         "encoder": {
             "type": "lstm",
             "input_size": 5,
             "hidden_size": 7,
             "num_layers": 2
         }
     })
     self.model = SimpleTagger.from_params(vocab=vocab,
                                           params=self.model_params)
Example #25
    def test_regex_match_prevention_prevents_and_overrides(self):
        class Net(torch.nn.Module):
            def __init__(self):
                super(Net, self).__init__()
                self.linear_1 = torch.nn.Linear(5, 10)
                self.linear_2 = torch.nn.Linear(10, 5)
                # typical actual usage: modules loaded via allennlp's Model.load(...)
                self.linear_3_transfer = torch.nn.Linear(5, 10)
                self.linear_4_transfer = torch.nn.Linear(10, 5)
                self.pretrained_conv = torch.nn.Conv1d(5, 5, 5)

            def forward(self, inputs):  # pylint: disable=arguments-differ
                pass

        json_params = """{"initializer": [
        [".*linear.*", {"type": "constant", "val": 10}],
        [".*conv.*", {"type": "constant", "val": 10}],
        [".*_transfer.*", "prevent"],
        [".*pretrained.*",{"type": "prevent"}]
        ]}
        """
        params = Params(json.loads(_jsonnet.evaluate_snippet("", json_params)))
        initializers = InitializerApplicator.from_params(params['initializer'])
        model = Net()
        initializers(model)

        for module in [model.linear_1, model.linear_2]:
            for parameter in module.parameters():
                assert torch.equal(parameter.data,
                                   torch.ones(parameter.size()) * 10)

        transferred_modules = [
            model.linear_3_transfer, model.linear_4_transfer,
            model.pretrained_conv
        ]

        for module in transferred_modules:
            for parameter in module.parameters():
                assert not torch.equal(parameter.data,
                                       torch.ones(parameter.size()) * 10)
Example #26
    def test_from_params(self):
        params = Params({
            "regularizers": [("conv", "l1"),
                             ("linear", {
                                 "type": "l2",
                                 "alpha": 10
                             })]
        })
        regularizer_applicator = RegularizerApplicator.from_params(
            params.pop("regularizers"))
        regularizers = regularizer_applicator._regularizers  # pylint: disable=protected-access

        conv = linear = None
        for regex, regularizer in regularizers:
            if regex == "conv":
                conv = regularizer
            elif regex == "linear":
                linear = regularizer

        assert isinstance(conv, L1Regularizer)
        assert isinstance(linear, L2Regularizer)
        assert linear.alpha == 10
Example #27
    def test_optimizer_parameter_groups(self):
        optimizer_params = Params({
            "type":
            "sgd",
            "lr":
            1,
            "momentum":
            5,
            "parameter_groups": [
                # the repeated "bias_" checks a corner case
                # NOT_A_VARIABLE_NAME displays a warning but does not raise an exception
                [["weight_i", "bias_", "bias_", "NOT_A_VARIABLE_NAME"], {
                    'lr': 2
                }],
                [["tag_projection_layer"], {
                    'lr': 3
                }],
            ]
        })
        parameters = [[n, p] for n, p in self.model.named_parameters()
                      if p.requires_grad]
        optimizer = Optimizer.from_params(parameters, optimizer_params)
        param_groups = optimizer.param_groups

        assert len(param_groups) == 3
        assert param_groups[0]['lr'] == 2
        assert param_groups[1]['lr'] == 3
        # base case uses default lr
        assert param_groups[2]['lr'] == 1
        for k in range(3):
            assert param_groups[k]['momentum'] == 5

        # all LSTM parameters except recurrent connections (those with weight_h in name)
        assert len(param_groups[0]['params']) == 6
        # just the projection weight and bias
        assert len(param_groups[1]['params']) == 2
        # the embedding + recurrent connections left in the default group
        assert len(param_groups[2]['params']) == 3
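The grouping behaviour being asserted puts each named parameter into a group whose regex list matches its name, with everything unmatched falling into a default group. A minimal standalone sketch of that idea on a bare PyTorch LSTM (illustrative only; AllenNLP's Optimizer.from_params has its own handling of duplicate and missing matches):

import re
import torch

model = torch.nn.LSTM(input_size=3, hidden_size=4, num_layers=1)
group_regexes = [["weight_ih", "bias"], ["weight_hh"]]

groups = [[] for _ in group_regexes] + [[]]  # the last group is the default
for name, param in model.named_parameters():
    for i, regexes in enumerate(group_regexes):
        if any(re.search(regex, name) for regex in regexes):
            groups[i].append(name)  # a simplification: first matching group wins
            break
    else:
        groups[-1].append(name)

print(groups)  # [['weight_ih_l0', 'bias_ih_l0', 'bias_hh_l0'], ['weight_hh_l0'], []]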
Example #28
    def test_valid_vocab_extension(self):
        vocab_dir = self.TEST_DIR / 'vocab_save'
        extension_ways = ["from_params", "extend_from_instances"]
        # Test: padded/non-padded common namespaces are extended appropriately
        non_padded_namespaces_list = [[], ["tokens"]]
        for non_padded_namespaces in non_padded_namespaces_list:
            original_vocab = Vocabulary(non_padded_namespaces=non_padded_namespaces)
            original_vocab.add_token_to_namespace("d", namespace="tokens")
            original_vocab.add_token_to_namespace("a", namespace="tokens")
            original_vocab.add_token_to_namespace("b", namespace="tokens")
            text_field = TextField([Token(t) for t in ["a", "d", "c", "e"]],
                                   {"tokens": SingleIdTokenIndexer("tokens")})
            instances = Batch([Instance({"text": text_field})])
            for way in extension_ways:
                if way == "extend_from_instances":
                    extended_vocab = copy.copy(original_vocab)
                    params = Params({"non_padded_namespaces": non_padded_namespaces})
                    extended_vocab.extend_from_instances(params, instances)
                else:
                    shutil.rmtree(vocab_dir, ignore_errors=True)
                    original_vocab.save_to_files(vocab_dir)
                    params = Params({"directory_path": vocab_dir, "extend": True,
                                     "non_padded_namespaces": non_padded_namespaces})
                    extended_vocab = Vocabulary.from_params(params, instances)

                extra_count = 2 if extended_vocab.is_padded("tokens") else 0
                assert extended_vocab.get_token_index("d", "tokens") == 0 + extra_count
                assert extended_vocab.get_token_index("a", "tokens") == 1 + extra_count
                assert extended_vocab.get_token_index("b", "tokens") == 2 + extra_count

                assert extended_vocab.get_token_index("c", "tokens") # should be present
                assert extended_vocab.get_token_index("e", "tokens") # should be present

                assert extended_vocab.get_vocab_size("tokens") == 5 + extra_count

        # Test: padded/non-padded non-common namespaces are extended appropriately
        non_padded_namespaces_list = [[],
                                      ["tokens1"],
                                      ["tokens1", "tokens2"]]
        for non_padded_namespaces in non_padded_namespaces_list:
            original_vocab = Vocabulary(non_padded_namespaces=non_padded_namespaces)
            original_vocab.add_token_to_namespace("a", namespace="tokens1")  # index 2 if padded, else 0
            text_field = TextField([Token(t) for t in ["b"]],
                                   {"tokens2": SingleIdTokenIndexer("tokens2")})
            instances = Batch([Instance({"text": text_field})])

            for way in extension_ways:
                if way == "extend_from_instances":
                    extended_vocab = copy.copy(original_vocab)
                    params = Params({"non_padded_namespaces": non_padded_namespaces})
                    extended_vocab.extend_from_instances(params, instances)
                else:
                    shutil.rmtree(vocab_dir, ignore_errors=True)
                    original_vocab.save_to_files(vocab_dir)
                    params = Params({"directory_path": vocab_dir, "extend": True,
                                     "non_padded_namespaces": non_padded_namespaces})
                    extended_vocab = Vocabulary.from_params(params, instances)

                # Should have two namespaces
                assert len(extended_vocab._token_to_index) == 2

                extra_count = 2 if extended_vocab.is_padded("tokens1") else 0
                assert extended_vocab.get_vocab_size("tokens1") == 1 + extra_count

                extra_count = 2 if extended_vocab.is_padded("tokens2") else 0
                assert extended_vocab.get_vocab_size("tokens2") == 1 + extra_count
Example #29
    def test_from_params_valid_vocab_extension_thoroughly(self):
        '''
        Tests valid vocab extension thoroughly: vocab extension is valid
        when overlapping namespaces have the same padding behaviour (padded/non-padded).
        Summary of namespace paddings in this test:
        original_vocab namespaces
            tokens0     padded
            tokens1     non-padded
            tokens2     padded
            tokens3     non-padded
        instances namespaces
            tokens0     padded
            tokens1     non-padded
            tokens4     padded
            tokens5     non-padded
        Typical extension example (for the tokens1 namespace):
        -> original_vocab index2token
           apple          #0->apple
           bat            #1->bat
           cat            #2->cat
        -> Tokens to extend with: cat, an, apple, banana, atom, bat
        -> extended_vocab: index2token
           apple           #0->apple
           bat             #1->bat
           cat             #2->cat
           an              #3->an
           atom            #4->atom
           banana          #5->banana
        '''

        vocab_dir = self.TEST_DIR / 'vocab_save'
        original_vocab = Vocabulary(non_padded_namespaces=["tokens1", "tokens3"])
        original_vocab.add_token_to_namespace("apple", namespace="tokens0") # index:2
        original_vocab.add_token_to_namespace("bat", namespace="tokens0")   # index:3
        original_vocab.add_token_to_namespace("cat", namespace="tokens0")   # index:4

        original_vocab.add_token_to_namespace("apple", namespace="tokens1") # index:0
        original_vocab.add_token_to_namespace("bat", namespace="tokens1")   # index:1
        original_vocab.add_token_to_namespace("cat", namespace="tokens1")   # index:2

        original_vocab.add_token_to_namespace("a", namespace="tokens2") # index:0
        original_vocab.add_token_to_namespace("b", namespace="tokens2") # index:1
        original_vocab.add_token_to_namespace("c", namespace="tokens2") # index:2

        original_vocab.add_token_to_namespace("p", namespace="tokens3") # index:0
        original_vocab.add_token_to_namespace("q", namespace="tokens3") # index:1

        original_vocab.save_to_files(vocab_dir)

        text_field0 = TextField([Token(t) for t in ["cat", "an", "apple", "banana", "atom", "bat"]],
                                {"tokens0": SingleIdTokenIndexer("tokens0")})
        text_field1 = TextField([Token(t) for t in ["cat", "an", "apple", "banana", "atom", "bat"]],
                                {"tokens1": SingleIdTokenIndexer("tokens1")})
        text_field4 = TextField([Token(t) for t in ["l", "m", "n", "o"]],
                                {"tokens4": SingleIdTokenIndexer("tokens4")})
        text_field5 = TextField([Token(t) for t in ["x", "y", "z"]],
                                {"tokens5": SingleIdTokenIndexer("tokens5")})
        instances = Batch([Instance({"text0": text_field0, "text1": text_field1,
                                     "text4": text_field4, "text5": text_field5})])

        params = Params({"directory_path": vocab_dir,
                         "extend": True,
                         "non_padded_namespaces": ["tokens1", "tokens5"]})
        extended_vocab = Vocabulary.from_params(params, instances)

        # namespaces tokens0 and tokens1 are common; tokens2 and tokens3 appear
        # only in the original vocab, tokens4 and tokens5 only in the instances
        extended_namespaces = {*extended_vocab._token_to_index}
        assert extended_namespaces == {"tokens{}".format(i) for i in range(6)}

        # Check that _non_padded_namespaces is consistent after extension
        assert extended_vocab._non_padded_namespaces == {"tokens1", "tokens3", "tokens5"}

        # original_vocab's "tokens1" has 3 tokens; the instances contribute 6, of which 3 overlap
        assert extended_vocab.get_vocab_size("tokens1") == 6
        assert extended_vocab.get_vocab_size("tokens0") == 8  # the same 6 tokens plus padding and unknown

        # namespaces tokens2 and tokens3 appeared only in original_vocab,
        # so their token counts should be unchanged in extended_vocab
        assert extended_vocab.get_vocab_size("tokens2") == original_vocab.get_vocab_size("tokens2")
        assert extended_vocab.get_vocab_size("tokens3") == original_vocab.get_vocab_size("tokens3")

        # namespaces tokens4 and tokens5 appeared only in instances,
        # so their token counts come entirely from the instances
        assert extended_vocab.get_vocab_size("tokens4") == 6 # l,m,n,o + oov + padding
        assert extended_vocab.get_vocab_size("tokens5") == 3 # x,y,z

        # Word2index mapping of all words in all namespaces of original_vocab
        # should be maintained in extended_vocab
        for namespace, token2index in original_vocab._token_to_index.items():
            for token, _ in token2index.items():
                vocab_index = original_vocab.get_token_index(token, namespace)
                extended_vocab_index = extended_vocab.get_token_index(token, namespace)
                assert vocab_index == extended_vocab_index
        # And same for Index2Word mapping
        for namespace, index2token in original_vocab._index_to_token.items():
            for index, _ in index2token.items():
                vocab_token = original_vocab.get_token_from_index(index, namespace)
                extended_vocab_token = extended_vocab.get_token_from_index(index, namespace)
                assert vocab_token == extended_vocab_token
Example #30
 def test_locally_normalised_span_extractor_can_build_from_params(self):
     params = Params({"type": "self_attentive", "input_dim": 5})
     extractor = SpanExtractor.from_params(params)
     assert isinstance(extractor, SelfAttentiveSpanExtractor)