Example #1
    def test_trainer_can_run_multiple_gpu(self):

        class MetaDataCheckWrapper(Model):
            """
            Checks that the metadata field has been correctly split across the batch dimension
            when running on multiple gpus.
            """
            def __init__(self, model):
                super().__init__(model.vocab)
                self.model = model

            def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint: disable=arguments-differ
                assert 'metadata' in kwargs and 'tags' in kwargs, \
                    f'metadata and tags must be provided. Got {kwargs.keys()} instead.'
                batch_size = kwargs['tokens']['tokens'].size()[0]
                assert len(kwargs['metadata']) == batch_size, \
                    f'metadata must be split appropriately. Expected {batch_size} elements, ' \
                    f"got {len(kwargs['metadata'])} elements."
                return self.model.forward(**kwargs)

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = Trainer(MetaDataCheckWrapper(self.model), self.optimizer,
                          multigpu_iterator, self.instances, num_epochs=2,
                          cuda_device=[0, 1])
        trainer.train()
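
Nearly every snippet on this page repeats the same three-step pattern: construct a BasicIterator, index it with a Vocabulary, then consume batches by calling the iterator. A minimal self-contained sketch of that pattern (assuming the AllenNLP 0.x API these examples target; the fixture path is illustrative):

from allennlp.data.dataset_readers import SequenceTaggingDatasetReader
from allennlp.data.iterators import BasicIterator
from allennlp.data.vocabulary import Vocabulary

reader = SequenceTaggingDatasetReader()
instances = reader.read('tests/fixtures/data/sequence_tagging.tsv')  # illustrative path
vocab = Vocabulary.from_instances(instances)

iterator = BasicIterator(batch_size=2)
iterator.index_with(vocab)  # must happen before instances are tensorized

for batch in iterator(instances, num_epochs=1, shuffle=False):
    # Each batch is a dict of tensors keyed by field name.
    print(batch['tokens']['tokens'].shape)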
Example #2
 def test_trainer_can_run_multiple_gpu(self):
     multigpu_iterator = BasicIterator(batch_size=4)
     multigpu_iterator.index_with(self.vocab)
     trainer = Trainer(self.model, self.optimizer,
                       multigpu_iterator, self.instances, num_epochs=2,
                       cuda_device=[0, 1])
     trainer.train()
Example #3
 def test_can_optimise_model_with_dense_and_sparse_params(self):
     optimizer_params = Params({
             "type": "dense_sparse_adam"
     })
     parameters = [[n, p] for n, p in self.model.named_parameters() if p.requires_grad]
     optimizer = Optimizer.from_params(parameters, optimizer_params)
     iterator = BasicIterator(2)
     iterator.index_with(self.vocab)
     Trainer(self.model, optimizer, iterator, self.instances).train()
Example #4
 def test_create_batches_groups_correctly(self):
     # pylint: disable=protected-access
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         batches = list(iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[self.instances[0], self.instances[1]],
                                      [self.instances[2], self.instances[3]],
                                      [self.instances[4]]]
Example #5
    def test_from_params(self):
        # pylint: disable=protected-access
        params = Params({})
        iterator = BasicIterator.from_params(params)
        assert iterator._batch_size == 32  # default value

        params = Params({"batch_size": 10})
        iterator = BasicIterator.from_params(params)
        assert iterator._batch_size == 10
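
The same Params route accepts the other constructor arguments exercised elsewhere on this page. A hedged sketch (the key names mirror the keyword arguments used in these examples; exact from_params support can vary across AllenNLP 0.x versions):

params = Params({
    "batch_size": 10,
    "instances_per_epoch": 100,
    "max_instances_in_memory": 1024,
    "track_epoch": True,
    "maximum_samples_per_batch": ["num_tokens", 9],
})
iterator = BasicIterator.from_params(params)
assert iterator._batch_size == 10  # pylint: disable=protected-access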
Example #6
    def test_epoch_tracking_multiple_epochs(self):
        iterator = BasicIterator(batch_size=2, track_epoch=True)
        iterator.index_with(self.vocab)

        all_batches = list(iterator(self.instances, num_epochs=10))
        assert len(all_batches) == 10 * 3
        for i, batch in enumerate(all_batches):
            # Should have 3 batches per epoch
            epoch = i // 3
            assert all(epoch_num == epoch for epoch_num in batch['epoch_num'])
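
Examples #6, #12, and #53 all rely on track_epoch=True, which adds an 'epoch_num' entry to each batch with one value per instance. A hedged, self-contained check, reusing the reader/vocab pattern from the sketch after Example #1:

reader = SequenceTaggingDatasetReader()
instances = reader.read('tests/fixtures/data/sequence_tagging.tsv')  # illustrative path
vocab = Vocabulary.from_instances(instances)

iterator = BasicIterator(batch_size=2, track_epoch=True)
iterator.index_with(vocab)

first_batch = next(iterator(instances, num_epochs=1))
# In epoch 0, every instance in the batch carries epoch_num == 0.
assert all(int(epoch_num) == 0 for epoch_num in first_batch['epoch_num'])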
Example #7
 def test_max_instances_in_memory(self):
     # pylint: disable=protected-access
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)
         # One epoch: 5 instances -> [2, 1, 2]
         batches = list(iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[self.instances[0], self.instances[1]],
                                      [self.instances[2]],
                                      [self.instances[3], self.instances[4]]]
Example #8
    def test_trainer_can_log_learning_rates_tensorboard(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          should_log_learning_rate=True,
                          summary_interval=2)

        trainer.train()
Example #9
 def test_yield_one_epoch_iterates_over_the_data_once(self):
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         iterator.index_with(self.vocab)
         batches = list(iterator(test_instances, num_epochs=1))
         # We just want to get the single-token array for the text field in the instance.
         instances = [tuple(instance.detach().cpu().numpy())
                      for batch in batches
                      for instance in batch['text']["tokens"]]
         assert len(instances) == 5
         self.assert_instances_are_correct(instances)
Example #10
 def test_call_iterates_over_data_forever(self):
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         iterator.index_with(self.vocab)
         generator = iterator(test_instances)
         batches = [next(generator) for _ in range(18)]  # going over the data 6 times
         # We just want to get the single-token array for the text field in the instance.
         instances = [tuple(instance.detach().cpu().numpy())
                      for batch in batches
                      for instance in batch['text']["tokens"]]
         assert len(instances) == 5 * 6
         self.assert_instances_are_correct(instances)
Example #11

    def test_with_iterator(self):
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        instances = reader.read(self.glob)

        iterator = BasicIterator(batch_size=32)
        iterator.index_with(self.vocab)

        batches = [batch for batch in iterator(instances, num_epochs=1)]

        # 400 instances / batch_size 32 = 12 full batches + 1 batch of 16
        sizes = sorted([len(batch['tags']) for batch in batches])
        assert sizes == [16] + 12 * [32]
Example #12
    def test_epoch_tracking_forever(self):
        iterator = BasicIterator(batch_size=2, track_epoch=True)
        iterator.index_with(self.vocab)

        it = iterator(self.instances, num_epochs=None)

        all_batches = [next(it) for _ in range(30)]

        assert len(all_batches) == 30
        for i, batch in enumerate(all_batches):
            # Should have 3 batches per epoch
            epoch = i // 3
            assert all(epoch_num == epoch for epoch_num in batch['epoch_num'])
Example #13
    def test_elmo_bilm(self):
        # get the raw data
        sentences, expected_lm_embeddings = self._load_sentences_embeddings()

        # load the test model
        elmo_bilm = _ElmoBiLm(self.options_file, self.weight_file)

        # Deal with the data.
        indexer = ELMoTokenCharactersIndexer()

        # For each sentence, first create a TextField, then create an instance
        instances = []
        for batch in zip(*sentences):
            for sentence in batch:
                tokens = [Token(token) for token in sentence.split()]
                field = TextField(tokens, {'character_ids': indexer})
                instance = Instance({"elmo": field})
                instances.append(instance)

        vocab = Vocabulary()

        # Now finally we can iterate through batches.
        iterator = BasicIterator(3)
        iterator.index_with(vocab)
        for i, batch in enumerate(iterator(instances, num_epochs=1, shuffle=False)):
            lm_embeddings = elmo_bilm(batch['elmo']['character_ids'])
            top_layer_embeddings, mask = remove_sentence_boundaries(
                    lm_embeddings['activations'][2],
                    lm_embeddings['mask']
            )

            # check the mask lengths
            lengths = mask.data.numpy().sum(axis=1)
            batch_sentences = [sentences[k][i] for k in range(3)]
            expected_lengths = [
                    len(sentence.split()) for sentence in batch_sentences
            ]
            self.assertEqual(lengths.tolist(), expected_lengths)

            # get the expected embeddings and compare!
            expected_top_layer = [expected_lm_embeddings[k][i] for k in range(3)]
            for k in range(3):
                self.assertTrue(
                        numpy.allclose(
                                top_layer_embeddings[k, :lengths[k], :].data.numpy(),
                                expected_top_layer[k],
                                atol=1.0e-6
                        )
                )
Example #14

def main(serialization_directory, device):
    """
    serialization_directory : str, required.
        The directory containing the serialized weights.
    device: int, default = -1
        The device to run the evaluation on.
    """

    config = Params.from_file(os.path.join(serialization_directory, "config.json"))
    dataset_reader = DatasetReader.from_params(config['dataset_reader'])
    evaluation_data_path = config['validation_data_path']

    model = Model.load(config, serialization_dir=serialization_directory, cuda_device=device)

    prediction_file_path = os.path.join(serialization_directory, "predictions.txt")
    gold_file_path = os.path.join(serialization_directory, "gold.txt")
    prediction_file = open(prediction_file_path, "w+")
    gold_file = open(gold_file_path, "w+")

    # Load the evaluation data and index it.
    print("Reading evaluation data from {}".format(evaluation_data_path))
    instances = dataset_reader.read(evaluation_data_path)
    iterator = BasicIterator(batch_size=32)
    iterator.index_with(model.vocab)

    model_predictions = []
    batches = iterator(instances, num_epochs=1, shuffle=False, cuda_device=device, for_training=False)
    for batch in Tqdm.tqdm(batches):
        result = model(**batch)
        predictions = model.decode(result)
        model_predictions.extend(predictions["tags"])

    for instance, prediction in zip(instances, model_predictions):
        fields = instance.fields
        try:
            # Most sentences have a verbal predicate, but not all.
            verb_index = fields["verb_indicator"].labels.index(1)
        except ValueError:
            verb_index = None

        gold_tags = fields["tags"].labels
        sentence = fields["tokens"].tokens

        write_to_conll_eval_file(prediction_file, gold_file,
                                 verb_index, sentence, prediction, gold_tags)
    prediction_file.close()
    gold_file.close()
Example #15
    def test_multiple_cursors(self):
        # pylint: disable=protected-access
        lazy_instances1 = _LazyInstances(lambda: (i for i in self.instances))
        lazy_instances2 = _LazyInstances(lambda: (i for i in self.instances))

        eager_instances1 = self.instances[:]
        eager_instances2 = self.instances[:]

        for instances1, instances2 in [(eager_instances1, eager_instances2),
                                       (lazy_instances1, lazy_instances2)]:
            iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
            iterator.index_with(self.vocab)

            # First epoch through dataset1
            batches = list(iterator._create_batches(instances1, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[0]], [self.instances[1]]]

            # First epoch through dataset2
            batches = list(iterator._create_batches(instances2, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[0]], [self.instances[1]]]

            # Second epoch through dataset1
            batches = list(iterator._create_batches(instances1, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[2]], [self.instances[3]]]

            # Second epoch through dataset2
            batches = list(iterator._create_batches(instances2, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[2]], [self.instances[3]]]
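
Example #15 works because the iterator keeps an independent cursor per dataset, so interleaved partial epochs over two datasets never disturb each other. A minimal sketch of that bookkeeping (not AllenNLP source; cursors are keyed by the identity of the instance collection):

cursors = {}

def next_epoch(instances, instances_per_epoch):
    # Resume from this dataset's own cursor, wrapping around at the end.
    start = cursors.get(id(instances), 0)
    taken = [instances[(start + i) % len(instances)] for i in range(instances_per_epoch)]
    cursors[id(instances)] = (start + instances_per_epoch) % len(instances)
    return taken

data1, data2 = list('abcde'), list('abcde')
assert next_epoch(data1, 2) == ['a', 'b']
assert next_epoch(data2, 2) == ['a', 'b']  # data2 has its own cursor
assert next_epoch(data1, 2) == ['c', 'd']  # data1 resumes where it stopped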
Example #16
    def test_maximum_samples_per_batch(self):
        for test_instances in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            iterator = BasicIterator(
                    batch_size=3, maximum_samples_per_batch=['num_tokens', 9]
            )
            iterator.index_with(self.vocab)
            batches = list(iterator._create_batches(test_instances, shuffle=False))
            stats = self.get_batches_stats(batches)

            # ensure all instances are in a batch
            assert stats['total_instances'] == len(self.instances)

            # ensure correct batch sizes
            assert stats['batch_lengths'] == [2, 1, 1, 1]

            # ensure correct sample sizes (<= 9)
            assert stats['sample_sizes'] == [8, 3, 9, 1]
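
The expected sample_sizes above (and in Examples #21, #23, and #26) follow from the padded-size rule that maximum_samples_per_batch enforces: a batch of n instances whose longest sequence has L tokens counts as n * L samples, and that product must not exceed the cap. A minimal sketch of the arithmetic (not AllenNLP source):

def padded_size(token_lengths):
    # n instances padded to the longest sequence cost n * max_length samples.
    return len(token_lengths) * max(token_lengths)

assert padded_size([4, 4]) == 8       # fits under a cap of 9
assert padded_size([4, 4, 4]) == 12   # a third instance would exceed the cap,
                                      # so the iterator emits [4, 4] early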
Example #17
    def test_shuffle(self):
        # pylint: disable=protected-access
        for test_instances in (self.instances, self.lazy_instances):

            iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

            in_order_batches = list(iterator._create_batches(test_instances, shuffle=False))
            shuffled_batches = list(iterator._create_batches(test_instances, shuffle=True))

            assert len(in_order_batches) == len(shuffled_batches)

            # With 100 instances, shuffling better change the order.
            assert in_order_batches != shuffled_batches

            # But not the counts of the instances.
            in_order_counts = Counter(id(instance) for batch in in_order_batches for instance in batch)
            shuffled_counts = Counter(id(instance) for batch in shuffled_batches for instance in batch)
            assert in_order_counts == shuffled_counts
Example #18
    def test_many_instances_per_epoch(self):
        # pylint: disable=protected-access
        for test_instances in (self.instances, self.lazy_instances):
            iterator = BasicIterator(batch_size=2, instances_per_epoch=7)
            # First epoch: 7 instances -> [2, 2, 2, 1]
            batches = list(iterator._create_batches(test_instances, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[0], self.instances[1]],
                                         [self.instances[2], self.instances[3]],
                                         [self.instances[4], self.instances[0]],
                                         [self.instances[1]]]

            # Second epoch: 7 instances -> [2, 2, 2, 1]
            batches = list(iterator._create_batches(test_instances, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[2], self.instances[3]],
                                         [self.instances[4], self.instances[0]],
                                         [self.instances[1], self.instances[2]],
                                         [self.instances[3]]]
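
Both epochs above wrap around the 5-instance dataset because instances_per_epoch=7 keeps a running cursor between epochs. A compact sketch of the wrap-around selection (not AllenNLP source):

from itertools import cycle, islice

def epoch_instances(instances, instances_per_epoch, cursor=0):
    # Take instances_per_epoch items starting at cursor, wrapping around.
    taken = list(islice(cycle(instances), cursor, cursor + instances_per_epoch))
    return taken, (cursor + instances_per_epoch) % len(instances)

data = [0, 1, 2, 3, 4]
epoch1, cursor = epoch_instances(data, 7)          # [0, 1, 2, 3, 4, 0, 1]
epoch2, cursor = epoch_instances(data, 7, cursor)  # [2, 3, 4, 0, 1, 2, 3]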
Example #19
 def test_multiprocess_iterate_partial_does_not_hang(self):
     for test_instances in (self.instances, self.lazy_instances):
         base_iterator = BasicIterator(batch_size=2,
                                       max_instances_in_memory=1024)
         iterator = MultiprocessIterator(base_iterator, num_workers=4)
         iterator.index_with(self.vocab)
         generator = iterator(test_instances, num_epochs=1)
          # We only iterate through 3 of the 5 instances, causing the
         # processes generating the tensors to remain active.
         for _ in range(3):
             next(generator)
Example #20
 def test_yield_one_epoch_iterates_over_the_data_once(self):
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         batches = list(iterator(test_instances, num_epochs=1))
         # We just want to get the single-token array for the text field in the instance.
         instances = [
             tuple(instance.detach().cpu().numpy()) for batch in batches
             for instance in batch['text']["tokens"]
         ]
         assert len(instances) == 5
         self.assert_instances_are_correct(instances)
Example #21
    def test_maximum_samples_per_batch(self):
        for test_instances in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            iterator = BasicIterator(
                    batch_size=3, maximum_samples_per_batch=['num_tokens', 9]
            )
            batches = list(iterator._create_batches(test_instances, shuffle=False))

            # ensure all instances are in a batch
            grouped_instances = [batch.instances for batch in batches]
            num_instances = sum(len(group) for group in grouped_instances)
            assert num_instances == len(self.instances)

            # ensure all batches are sufficiently small
            for batch in batches:
                batch_sequence_length = max(
                        [instance.get_padding_lengths()['text']['num_tokens']
                         for instance in batch.instances]
                )
                assert batch_sequence_length * len(batch.instances) <= 9
Example #22
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            dpd_output_directory=wikitables_dir + 'dpd_output/')
        instances = wikitables_reader.read(wikitables_dir +
                                           'sample_data.examples')
        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)
        trainer = Trainer(model,
                          self.optimizer,
                          multigpu_iterator,
                          instances,
                          num_epochs=2,
                          cuda_device=[0, 1])
        trainer.train()
Example #23
    def test_maximum_samples_per_batch_packs_tightly(self):
        # pylint: disable=protected-access
        token_counts = [10, 4, 3]
        test_instances = self.create_instances_from_token_counts(token_counts)

        iterator = BasicIterator(
                batch_size=3, maximum_samples_per_batch=['num_tokens', 11]
        )
        iterator.index_with(self.vocab)
        batches = list(iterator._create_batches(test_instances, shuffle=False))
        stats = self.get_batches_stats(batches)

        # ensure all instances are in a batch
        assert stats['total_instances'] == len(token_counts)

        # ensure correct batch sizes
        assert stats['batch_lengths'] == [1, 2]

        # ensure correct sample sizes (<= 11)
        assert stats['sample_sizes'] == [10, 8]
Example #24

    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        # To test:
        #   Create an iterator that sleeps for 2.5 seconds per epoch, so the total training
        #       time for one epoch is slightly greater than 2.5 seconds.
        #   Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds.
        #   Check the resulting checkpoints.  Should then have models at epochs
        #       2, 4, plus the last two at 5 and 6.
        class WaitingIterator(BasicIterator):
            def _create_batches(self, *args, **kwargs):
                time.sleep(2.5)
                return super()._create_batches(*args, **kwargs)

        waiting_iterator = WaitingIterator(batch_size=2)
        waiting_iterator.index_with(self.vocab)

        # Don't want validation iterator to wait.
        viterator = BasicIterator(batch_size=2)
        viterator.index_with(self.vocab)

        trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=waiting_iterator,
            optimizer=self.optimizer,
            num_epochs=6,
            serialization_dir=self.TEST_DIR,
            callbacks=self.default_callbacks(max_checkpoints=2,
                                             checkpoint_every=5,
                                             validation_iterator=viterator),
        )
        trainer.train()

        # Now check the serialized files
        for prefix in ["model_state_epoch_*", "training_state_epoch_*"]:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [
                int(re.search(r"_([0-9])\.th", fname).group(1))
                for fname in file_names
            ]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]
Example #25

    def setUp(self):
        super().setUp()

        # A lot of the tests want access to the metric tracker
        # so we add a property that gets it by grabbing it from
        # the relevant callback.
        def metric_tracker(self: CallbackTrainer):
            for callback in self.handler.callbacks():
                if isinstance(callback, TrackMetrics):
                    return callback.metric_tracker
            return None

        setattr(CallbackTrainer, "metric_tracker", property(metric_tracker))

        self.instances = SequenceTaggingDatasetReader().read(
            self.FIXTURES_ROOT / "data" / "sequence_tagging.tsv")
        vocab = Vocabulary.from_instances(self.instances)
        self.vocab = vocab
        self.model_params = Params({
            "text_field_embedder": {
                "token_embedders": {
                    "tokens": {
                        "type": "embedding",
                        "embedding_dim": 5
                    }
                }
            },
            "encoder": {
                "type": "lstm",
                "input_size": 5,
                "hidden_size": 7,
                "num_layers": 2
            },
        })
        self.model = SimpleTagger.from_params(vocab=self.vocab,
                                              params=self.model_params)
        self.optimizer = torch.optim.SGD(self.model.parameters(),
                                         0.01,
                                         momentum=0.9)
        self.iterator = BasicIterator(batch_size=2)
        self.iterator.index_with(vocab)
Example #26
    def test_maximum_samples_per_batch(self):
        for test_instances in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            iterator = BasicIterator(
                batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            iterator.index_with(self.vocab)
            batches = list(
                iterator._create_batches(test_instances, shuffle=False))

            # ensure all instances are in a batch
            grouped_instances = [batch.instances for batch in batches]
            num_instances = sum(len(group) for group in grouped_instances)
            assert num_instances == len(self.instances)

            # ensure all batches are sufficiently small
            for batch in batches:
                batch_sequence_length = max([
                    instance.get_padding_lengths()['text']['num_tokens']
                    for instance in batch.instances
                ])
                assert batch_sequence_length * len(batch.instances) <= 9
Example #27
    def test_kg_reader_with_eval(self):
        train_file = 'tests/fixtures/kg_embeddings/wn18rr_train.txt'
        dev_file = 'tests/fixtures/kg_embeddings/wn18rr_dev.txt'

        train_instances = KGTupleReader().read(train_file)

        reader = KGTupleReader(extra_files_for_gold_pairs=[train_file])
        instances = reader.read(dev_file)
        self.assertEqual(len(instances), 2)

        vocab = Vocabulary.from_params(Params({}), train_instances + instances)
        iterator = BasicIterator(batch_size=32)
        iterator.index_with(vocab)
        for batch in iterator(instances, num_epochs=1, shuffle=False):
            pass

        expected_entity = [1, 5]
        expected_relation = ['_hypernym', '_hypernym_reverse']
        expected_entity2 = [[5, 2, 3], [1, 4]]
        self._check_batch(batch, vocab,
                          expected_entity, expected_relation, expected_entity2)
Example #28
    def test_shuffle(self):
        for test_instances in (self.instances, self.lazy_instances):

            iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

            in_order_batches = list(iterator._create_batches(test_instances, shuffle=False))
            shuffled_batches = list(iterator._create_batches(test_instances, shuffle=True))

            assert len(in_order_batches) == len(shuffled_batches)

            # With 100 instances, shuffling better change the order.
            assert in_order_batches != shuffled_batches

            # But not the counts of the instances.
            in_order_counts = Counter(
                id(instance) for batch in in_order_batches for instance in batch
            )
            shuffled_counts = Counter(
                id(instance) for batch in shuffled_batches for instance in batch
            )
            assert in_order_counts == shuffled_counts
Example #29
 def test_call_iterates_over_data_forever(self):
     for test_instances in (self.instances, self.lazy_instances):
         generator = BasicIterator(batch_size=2)(test_instances)
         batches = [next(generator)
                    for _ in range(18)]  # going over the data 6 times
         # We just want to get the single-token array for the text field in the instance.
         instances = [
             tuple(instance.detach().cpu().numpy()) for batch in batches
             for instance in batch['text']["tokens"]
         ]
         assert len(instances) == 5 * 6
         self.assert_instances_are_correct(instances)
Example #30
    def test_elmo_bilm(self):
        # get the raw data
        sentences, expected_lm_embeddings = self._load_sentences_embeddings()

        # load the test model
        options_file = os.path.join(FIXTURES, 'options.json')
        weight_file = os.path.join(FIXTURES, 'lm_weights.hdf5')
        elmo_bilm = _ElmoBiLm(options_file, weight_file)

        # Deal with the data.
        indexer = ELMoTokenCharactersIndexer()

        # For each sentence, first create a TextField, then create an instance
        instances = []
        for batch in zip(*sentences):
            for sentence in batch:
                tokens = [Token(token) for token in sentence.split()]
                field = TextField(tokens, {'character_ids': indexer})
                instance = Instance({"elmo": field})
                instances.append(instance)

        dataset = Dataset(instances)
        vocab = Vocabulary()
        dataset.index_instances(vocab)

        # Now finally we can iterate through batches.
        iterator = BasicIterator(3)
        for i, batch in enumerate(iterator(dataset, num_epochs=1, shuffle=False)):
            batch_tensor = Variable(torch.from_numpy(batch['elmo']['character_ids']))
            lm_embeddings = elmo_bilm(batch_tensor)
            top_layer_embeddings, mask = remove_sentence_boundaries(
                    lm_embeddings['activations'][2],
                    lm_embeddings['mask']
            )

            # check the mask lengths
            lengths = mask.data.numpy().sum(axis=1)
            batch_sentences = [sentences[k][i] for k in range(3)]
            expected_lengths = [
                    len(sentence.split()) for sentence in batch_sentences
            ]
            self.assertEqual(lengths.tolist(), expected_lengths)

            # get the expected embeddings and compare!
            expected_top_layer = [expected_lm_embeddings[k][i] for k in range(3)]
            for k in range(3):
                self.assertTrue(
                        numpy.allclose(
                                top_layer_embeddings[k, :lengths[k], :].data.numpy(),
                                expected_top_layer[k],
                                atol=1.0e-6
                        )
                )
Example #31
class InKBAllEntitiesEncoder:
    def __init__(self, args, entity_loader_datasetreaderclass, entity_encoder_wrapping_model, vocab):
        self.args = args
        self.entity_loader_datasetreader = entity_loader_datasetreaderclass
        self.sequence_iterator_for_encoding_entities = BasicIterator(batch_size=args.batch_size_for_kb_encoder)
        self.vocab = vocab
        self.entity_encoder_wrapping_model = entity_encoder_wrapping_model
        self.entity_encoder_wrapping_model.eval()
        self.cuda_device = 0

    def encoding_all_entities(self):
        duidx2emb = {}
        ds = self.entity_loader_datasetreader.read('test')
        self.sequence_iterator_for_encoding_entities.index_with(self.vocab)
        entity_generator = self.sequence_iterator_for_encoding_entities(ds, num_epochs=1, shuffle=False)
        entity_generator_tqdm = tqdm(entity_generator,
                                     total=self.sequence_iterator_for_encoding_entities.get_num_batches(ds))
        print('======Encoding all entities from title and description=====')
        
        entities_full_path = os.path.join(self.args.entities_path, self.args.entities_filename)
        if self.args.load_entities:
            duidx2emb = pickle_load_object(entities_full_path)
        else:
            with torch.no_grad():
                for batch in entity_generator_tqdm:
                    batch = nn_util.move_to_device(batch, self.cuda_device)
                    duidxs, embs = self._extract_cuidx_and_its_encoded_emb(batch)
                    for duidx, emb in zip(duidxs, embs):
                        duidx2emb[int(duidx)] = emb
            if self.args.save_entities:
                pickle_save_object(duidx2emb, entities_full_path)

        return duidx2emb

    def tonp(self, tsr):
        return tsr.detach().cpu().numpy()

    def _extract_cuidx_and_its_encoded_emb(self, batch) -> Tuple[np.ndarray, np.ndarray]:
        out_dict = self.entity_encoder_wrapping_model(**batch)
        return self.tonp(out_dict['gold_duidx']), self.tonp(out_dict['emb_of_entities_encoded'])
Example #32

    def test_multiprocess_reader_with_multiprocess_iterator(self):
        # use SequenceTaggingDatasetReader as the base reader
        reader = MultiprocessDatasetReader(base_reader=self.base_reader, num_workers=2)
        base_iterator = BasicIterator(batch_size=32, max_instances_in_memory=1024)

        iterator = MultiprocessIterator(base_iterator, num_workers=2)
        iterator.index_with(self.vocab)

        instances = reader.read(self.glob)

        tensor_dicts = iterator(instances, num_epochs=1)
        sizes = [len(tensor_dict["tags"]) for tensor_dict in tensor_dicts]
        assert sum(sizes) == 400
Example #33

 def test_yield_one_epoch_iterates_over_the_data_once(self):
     for test_instances in (self.instances, self.lazy_instances):
         base_iterator = BasicIterator(batch_size=2, max_instances_in_memory=1024)
         iterator = MultiprocessIterator(base_iterator, num_workers=4)
         iterator.index_with(self.vocab)
         batches = list(iterator(test_instances, num_epochs=1))
         # We just want to get the single-token array for the text field in the instance.
         instances = [
             tuple(instance.detach().cpu().numpy())
             for batch in batches
             for instance in batch["text"]["tokens"]["tokens"]
         ]
         assert len(instances) == 5
Example #34
def get_loss_per_candidate_squad(
    index,
    model,
    trigger_token_ids,
    cand_trigger_token_ids,
    vocab,
    dev_dataset,
    span_start,
    span_end,
):
    """
    Similar to get_loss_per_candidate, except that we use multiple batches (in this case 4) rather than one
    to evaluate the top trigger token candidates.
    """
    if isinstance(cand_trigger_token_ids[0], (numpy.int64, int)):
        print("Only 1 candidate for index detected, not searching")
        return trigger_token_ids
    model.get_metrics(reset=True)
    loss_per_candidate = []
    iterator = BasicIterator(batch_size=32)
    batch_count = 0
    curr_loss = 0.0
    for batch in lazy_groups_of(iterator(dev_dataset,
                                         num_epochs=1,
                                         shuffle=True),
                                group_size=1):
        if batch_count > 4:
            continue
        batch_count = batch_count + 1
        curr_loss += (evaluate_batch_squad(
            model, batch, trigger_token_ids, vocab, span_start,
            span_end)["loss"].cpu().detach().numpy())
    loss_per_candidate.append((deepcopy(trigger_token_ids), curr_loss))

    for cand_id in range(len(cand_trigger_token_ids[0])):
        temp_trigger_token_ids = deepcopy(trigger_token_ids)
        temp_trigger_token_ids[index] = cand_trigger_token_ids[index][cand_id]
        loss = 0
        batch_count = 0
        for batch in lazy_groups_of(iterator(dev_dataset,
                                             num_epochs=1,
                                             shuffle=True),
                                    group_size=1):
            if batch_count > 4:
                continue
            batch_count = batch_count + 1
            loss += (evaluate_batch_squad(
                model, batch, temp_trigger_token_ids, vocab, span_start,
                span_end)["loss"].cpu().detach().numpy())
        loss_per_candidate.append((deepcopy(temp_trigger_token_ids), loss))
    return loss_per_candidate
Example #35
    def test_trainer_can_run_multiple_gpu(self):
        self.model.cuda()

        class MetaDataCheckWrapper(Model):
            """
            Checks that the metadata field has been correctly split across the batch dimension
            when running on multiple gpus.
            """
            def __init__(self, model):
                super().__init__(model.vocab)
                self.model = model

            def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint: disable=arguments-differ
                assert 'metadata' in kwargs and 'tags' in kwargs, \
                    f'metadata and tags must be provided. Got {kwargs.keys()} instead.'
                batch_size = kwargs['tokens']['tokens'].size()[0]
                assert len(kwargs['metadata']) == batch_size, \
                    f'metadata must be split appropriately. Expected {batch_size} elements, ' \
                    f"got {len(kwargs['metadata'])} elements."
                return self.model.forward(**kwargs)

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = CallbackTrainer(MetaDataCheckWrapper(self.model),
                                  training_data=self.instances,
                                  iterator=multigpu_iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=2,
                                  callbacks=self.default_callbacks(),
                                  cuda_device=[0, 1])
        metrics = trainer.train()
        assert 'peak_cpu_memory_MB' in metrics
        assert isinstance(metrics['peak_cpu_memory_MB'], float)
        assert metrics['peak_cpu_memory_MB'] > 0
        assert 'peak_gpu_0_memory_MB' in metrics
        assert isinstance(metrics['peak_gpu_0_memory_MB'], int)
        assert 'peak_gpu_1_memory_MB' in metrics
        assert isinstance(metrics['peak_gpu_1_memory_MB'], int)
Example #36
    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
                  for fname in file_names]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == '1'
        assert '.' in epochs[0]

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the checkpoints from the ends of epochs 1 and 2, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(os.path.join(self.TEST_DIR, 'model_state_epoch_{}.th'.format(k)))
            os.remove(os.path.join(self.TEST_DIR, 'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model, self.optimizer,
                                  self.iterator, self.instances, num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint() # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2 # pylint: disable=protected-access
Example #37
    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
                  for fname in file_names]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == '1'
        assert '.' in epochs[0]

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the checkpoints from the ends of epochs 1 and 2, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(os.path.join(self.TEST_DIR, 'model_state_epoch_{}.th'.format(k)))
            os.remove(os.path.join(self.TEST_DIR, 'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model, self.optimizer,
                                  self.iterator, self.instances, num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
Example #38
def pipeline_function_list(item_list, model, vocab, cursor):
    seed = 12
    batch_size = 128
    dev_batch_size = 128
    lazy = True
    torch.manual_seed(seed)
    contain_first_sentence = True

    # Prepare Data
    token_indexers = {
        'tokens':
        SingleIdTokenIndexer(namespace='tokens'),  # This is the raw tokens
        'elmo_chars': ELMoTokenCharactersIndexer(
            namespace='elmo_characters')  # This is the elmo_characters
    }

    dev_fever_data_reader = SSelectorReader(token_indexers=token_indexers,
                                            lazy=lazy,
                                            max_l=180)

    complete_upstream_dev_data = disamb.sample_disamb_inference(
        item_list, cursor, contain_first_sentence=contain_first_sentence)
    dev_instances = dev_fever_data_reader.read(complete_upstream_dev_data)

    # Load Vocabulary
    biterator = BasicIterator(batch_size=batch_size)
    dev_biterator = BasicIterator(batch_size=dev_batch_size)

    # This is important
    vocab.add_token_to_namespace("true", namespace="selection_labels")
    vocab.add_token_to_namespace("false", namespace="selection_labels")
    vocab.add_token_to_namespace("hidden", namespace="selection_labels")
    vocab.change_token_with_index_to_namespace("hidden",
                                               -2,
                                               namespace='selection_labels')
    # Label value

    vocab.get_index_to_token_vocabulary('selection_labels')

    biterator.index_with(vocab)
    dev_biterator.index_with(vocab)

    eval_iter = dev_biterator(dev_instances, shuffle=False, num_epochs=1)
    complete_upstream_dev_data = hidden_eval(model, eval_iter,
                                             complete_upstream_dev_data)

    dev_doc_score_list = complete_upstream_dev_data
    return dev_doc_score_list
Example #39
    def test_many_instances_per_epoch(self):
        # pylint: disable=protected-access
        for test_instances in (self.instances, self.lazy_instances):
            iterator = BasicIterator(batch_size=2, instances_per_epoch=7)
            # First epoch: 7 instances -> [2, 2, 2, 1]
            batches = list(
                iterator._create_batches(test_instances, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[0], self.instances[1]],
                                         [self.instances[2], self.instances[3]],
                                         [self.instances[4], self.instances[0]],
                                         [self.instances[1]]]

            # Second epoch: 7 instances -> [2, 2, 2, 1]
            batches = list(
                iterator._create_batches(test_instances, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[2], self.instances[3]],
                                         [self.instances[4], self.instances[0]],
                                         [self.instances[1], self.instances[2]],
                                         [self.instances[3]]]
Example #40
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
        offline_lf_directory = wikitables_dir + 'action_space_walker_output/'
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
        instances = wikitables_reader.read(wikitables_dir +
                                           'sample_data.examples')
        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)

        trainer = CallbackTrainer(model,
                                  instances,
                                  multigpu_iterator,
                                  self.optimizer,
                                  num_epochs=2,
                                  cuda_device=[0, 1],
                                  callbacks=[GradientNormAndClip()])
        trainer.train()
Example #41
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = "allennlp/tests/fixtures/data/wikitables/"
        search_output_directory = wikitables_dir + "action_space_walker_output/"
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir, offline_logical_forms_directory=search_output_directory
        )
        instances = wikitables_reader.read(wikitables_dir + "sample_data.examples")
        archive_path = (
            self.FIXTURES_ROOT
            / "semantic_parsing"
            / "wikitables"
            / "serialization"
            / "model.tar.gz"
        )
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)
        trainer = Trainer(
            model, self.optimizer, multigpu_iterator, instances, num_epochs=2, cuda_device=[0, 1]
        )
        trainer.train()
Example #42
 def setUp(self):
     super(TestTrainer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read('tests/fixtures/data/sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params({
             "text_field_embedder": {
                     "tokens": {
                             "type": "embedding",
                             "embedding_dim": 5
                             }
                     },
             "stacked_encoder": {
                     "type": "lstm",
                     "input_size": 5,
                     "hidden_size": 7,
                     "num_layers": 2
                     }
             })
     self.model = SimpleTagger.from_params(self.vocab, self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
     self.iterator.index_with(vocab)
Example #43
 def setUp(self):
     super(TestTrainer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / u'data' / u'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params({
             u"text_field_embedder": {
                     u"tokens": {
                             u"type": u"embedding",
                             u"embedding_dim": 5
                             }
                     },
             u"encoder": {
                     u"type": u"lstm",
                     u"input_size": 5,
                     u"hidden_size": 7,
                     u"num_layers": 2
                     }
             })
     self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
     self.iterator.index_with(vocab)
Example #44
def test_iterator():
    indexer = StaticFasttextTokenIndexer(
        model_path="./data/fasttext_embedding.model",
        model_params_path="./data/fasttext_embedding.model.params")

    loader = MenionsLoader(
        category_mapping_file='./data/test_category_mapping.json',
        token_indexers={"tokens": indexer},
        tokenizer=WordTokenizer(word_splitter=FastSplitter()))

    vocab = Vocabulary.from_params(Params({"directory_path":
                                           "./data/vocab2/"}))

    iterator = BasicIterator(batch_size=32)

    iterator.index_with(vocab)

    limit = 50
    for _ in tqdm.tqdm(iterator(loader.read('./data/train_data_aa.tsv'),
                                num_epochs=1),
                       mininterval=2):
        limit -= 1
        if limit <= 0:
            break
Example #45
    def __init__(self, reader, train_path, test_path, batch_dims):
        train_batch_dim, test_batch_dim = batch_dims

        train_dataset = reader.read(train_path)
        test_dataset = reader.read(test_path)
        vocab = Vocabulary.from_instances(train_dataset)

        train_iterator = BasicIterator(batch_size=train_batch_dim)
        train_iterator.index_with(vocab)
        test_iterator = BasicIterator(batch_size=test_batch_dim)
        test_iterator.index_with(vocab)

        self._label_map = vocab._index_to_token["labels"]
        self._iterators = {
            "train": (train_iterator, train_dataset),
            "test": (test_iterator, test_dataset)
        }
Example #46
def main():
    all_chars = {END_SYMBOL, START_SYMBOL}
    all_chars.update("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ .,!?'-")
    token_counts = {char: 1 for char in all_chars}
    vocab = Vocabulary({'tokens': token_counts})

    token_indexers = {'tokens': SingleIdTokenIndexer()}

    train_set = read_dataset(all_chars)
    instances = [tokens_to_lm_instance(tokens, token_indexers)
                 for tokens in train_set]

    token_embedding = Embedding(num_embeddings=vocab.get_vocab_size('tokens'),
                                embedding_dim=EMBEDDING_SIZE)
    embedder = BasicTextFieldEmbedder({"tokens": token_embedding})

    model = RNNLanguageModel(embedder=embedder,
                             hidden_size=HIDDEN_SIZE,
                             max_len=80,
                             vocab=vocab)

    iterator = BasicIterator(batch_size=BATCH_SIZE)
    iterator.index_with(vocab)

    optimizer = optim.Adam(model.parameters(), lr=5.e-3)

    trainer = Trainer(model=model,
                      optimizer=optimizer,
                      iterator=iterator,
                      train_dataset=instances,
                      num_epochs=10)
    trainer.train()

    for _ in range(50):
        tokens, _ = model.generate()
        print(''.join(token.text for token in tokens))
Example #47
 def __init__(self,
              data_dir,
              batch_size: int,
              shuffle=False,
              small_data=False,
              train_name='train.json',
              dev_name='dev.json',
              test_name='test.json'):
     super().__init__()
     self.data_dir = data_dir
     print('loading dataset: ' + os.path.join(data_dir, train_name))
     self.train_dataset = self.read(os.path.join(data_dir, train_name))
     print('loading val dataset: ' + os.path.join(data_dir, dev_name))
     self.validation_dataset = self.read(os.path.join(data_dir, dev_name))
     self.vocab = Vocabulary.from_instances(self.train_dataset +
                                            self.validation_dataset)
     print('loading test dataset:' + os.path.join(data_dir, test_name))
     self.test_dataset = self.read(os.path.join(data_dir, test_name))
     self.batch_size = batch_size
     self.shuffle = shuffle
     self.small_data = small_data
     self.iterator = BasicIterator(batch_size=batch_size,
                                   cache_instances=True)
     self.iterator.index_with(self.vocab)
Example #48
def _filter_data(data, vocab):
    def _is_correct_instance(batch):
        assert len(batch['words']['ru_bert']['offsets']) == 1

        if batch['words']['ru_bert']['token_ids'].shape[1] > 256:
            return False

        return all(
            begin <= end < batch['words']['ru_bert']['token_ids'].shape[1]
            for begin, end in batch['words']['ru_bert']['offsets'][0])

    iterator = BasicIterator(batch_size=1)
    iterator.index_with(vocab)

    result_data = []
    for instance in tqdm(data):
        batch = next(iterator([instance]))
        if _is_correct_instance(batch):
            result_data.append(instance)
        else:
            logger.info('Filtering out %s', batch['metadata'][0]['words'])

    logger.info('Removed %s samples', len(data) - len(result_data))
    return result_data
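
Example #48 doubles as a handy trick: a batch_size=1 BasicIterator serves as a one-instance tensorizer, convenient for inspecting (or, as above, filtering) how a single instance is rendered as model input. A minimal sketch, assuming an instance and an indexed vocab like those above:

iterator = BasicIterator(batch_size=1)
iterator.index_with(vocab)

# num_epochs defaults to None (iterate forever), so take just the first batch.
batch = next(iterator([instance]))
print({key: type(value) for key, value in batch.items()})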
Example #49
def main(serialization_directory, device):
    """
    serialization_directory : str, required.
        The directory containing the serialized weights.
    device: int, default = -1
        The device to run the evaluation on.
    """

    config = Params.from_file(os.path.join(serialization_directory, "model_params.json"))
    dataset_reader = DatasetReader.from_params(config['dataset_reader'])
    evaluation_data_path = config['validation_data_path']

    model = Model.load(config, serialization_dir=serialization_directory, cuda_device=device)

    prediction_file_path = os.path.join(serialization_directory, "predictions.txt")
    gold_file_path = os.path.join(serialization_directory, "gold.txt")
    prediction_file = open(prediction_file_path, "w+")
    gold_file = open(gold_file_path, "w+")

    # Load the evaluation data and index it.
    print("Reading evaluation data from {}".format(evaluation_data_path))
    dataset = dataset_reader.read(evaluation_data_path)
    dataset.index_instances(model._vocab)
    iterator = BasicIterator(batch_size=32)

    model_predictions = []
    for batch in tqdm.tqdm(iterator(dataset, num_epochs=1, shuffle=False)):
        tensor_batch = arrays_to_variables(batch, device, for_training=False)
        result = model(**tensor_batch)
        predictions = model.decode(result)
        model_predictions.extend(predictions["tags"])

    for instance, prediction in zip(dataset.instances, model_predictions):
        fields = instance.fields
        predicted_tags = [model._vocab.get_token_from_index(x, namespace="labels") for x in prediction]
        try:
            # Most sentences have a verbal predicate, but not all.
            verb_index = fields["verb_indicator"].labels.index(1)
        except ValueError:
            verb_index = None

        gold_tags = fields["tags"].labels
        sentence = fields["tokens"].tokens

        write_to_conll_eval_file(prediction_file, gold_file,
                                 verb_index, sentence, predicted_tags, gold_tags)
    prediction_file.close()
    gold_file.close()
Example #50
 def __init__(self, dataset_reader: DatasetReader, data_iterator: DataIterator = None,
              evaluation_command: BaseEvaluationCommand = None, model: Model = None,
              batch_size: int = 64):
     """
     Creates a predictor from an AMConllDatasetReader, optionally takes an AllenNLP model. The model can also be given later using set_model.
     If evaluation is required, an evaluation_command can be supplied as well.
     :param dataset_reader: an AMConllDatasetReader
     :param evaluation_command:
     :param model:
     """
     assert isinstance(dataset_reader, AMConllDatasetReader), "A predictor in the am-parser must take an AMConllDatasetReader"
     self.dataset_reader = dataset_reader
     self.model = model
     self.evaluation_command = evaluation_command
     self.batch_size = batch_size
     if data_iterator is None:
         self.data_iterator = BasicIterator()
     else:
         self.data_iterator = data_iterator
Example #51
def evaluate_from_args(args: argparse.Namespace) -> Dict[str, Any]:
    # Disable some of the more verbose logging statements
    logging.getLogger('allennlp.common.params').disabled = True
    logging.getLogger('allennlp.nn.initializers').disabled = True
    logging.getLogger('allennlp.modules.token_embedders.embedding').setLevel(logging.INFO)

    # Load from archive
    archive = load_archive(args.archive_file, args.cuda_device, args.overrides)
    config = archive.config
    prepare_environment(config)
    model = archive.model
    model.eval()

    # Load the evaluation data
    dataset_reader = DatasetReader.from_params(config.pop('dataset_reader'))
    evaluation_data_path = args.evaluation_data_file
    logger.info("Reading evaluation data from %s", evaluation_data_path)
    dataset = dataset_reader.read(evaluation_data_path)
    dataset.index_instances(model.vocab)

    iterator = BasicIterator(batch_size=32)

    serialization_directory = args.archive_file[:-13]
    metrics = evaluate(model,
                       dataset,
                       iterator,
                       args.cuda_device,
                       serialization_directory)

    metrics_file_path = os.path.join(serialization_directory, "metrics.txt")
    metrics_file = open(metrics_file_path, "w+")

    logger.info("Finished evaluating.")
    logger.info("Metrics:")
    for key, metric in metrics.items():
        if "overall" in key:
            logger.info("%s: %s", key, metric)
        if "gold_spans" in key or "predicted_spans" in key:
            continue
        metrics_file.write("{}: {}\n".format(key, metric))
    metrics_file.close()
    logger.info("Detailed evaluation metrics in %s", metrics_file_path)
    return metrics
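
# For reference, a minimal sketch of the argument parser this entry point
# expects. The flag names are inferred from the attributes read above and are
# assumptions, not the original CLI definition.
import argparse

parser = argparse.ArgumentParser(description="Evaluate a model archive.")
parser.add_argument('archive_file', help='path ending in /model.tar.gz')
parser.add_argument('evaluation_data_file', help='path to the evaluation data')
parser.add_argument('--cuda_device', type=int, default=-1, help='GPU id, or -1 for CPU')
parser.add_argument('--overrides', type=str, default="", help='JSON overrides for the archived config')
metrics = evaluate_from_args(parser.parse_args())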
Example No. 52
0
 def setUp(self):
     super(TestTrainer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params({
             "text_field_embedder": {
                     "tokens": {
                             "type": "embedding",
                             "embedding_dim": 5
                             }
                     },
             "encoder": {
                     "type": "lstm",
                     "input_size": 5,
                     "hidden_size": 7,
                     "num_layers": 2
                     }
             })
     self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
     self.iterator.index_with(vocab)
Example No. 53
0
 def test_epoch_tracking_when_one_epoch_at_a_time(self):
     iterator = BasicIterator(batch_size=2, track_epoch=True)
     iterator.index_with(self.vocab)
     for epoch in range(10):
         for batch in iterator(self.instances, num_epochs=1):
             assert all(epoch_num == epoch for epoch_num in batch['epoch_num'])
Example No. 54
0
def main(serialization_directory: str,
         device: int,
         data: str,
         prefix: str,
         domain: str = None):
    """
    serialization_directory : str, required.
        The directory containing the serialized weights.
    device: int, default = -1
        The device to run the evaluation on.
    data: str, default = None
        The data to evaluate on. By default, we use the validation data from
        the original experiment.
    prefix: str, default=""
        The prefix to prepend to the generated gold and prediction files, to distinguish
        different models/data.
    domain: str, optional (default = None)
        If passed, filters the ontonotes evaluation/test dataset to only contain the
        specified domain. This overwrites the domain in the config file from the model,
        to allow evaluation on domains other than the one the model was trained on.
    """
    config = Params.from_file(os.path.join(serialization_directory, "config.json"))

    if domain is not None:
        # Hack to allow evaluation on different domains than the
        # model was trained on.
        config["dataset_reader"]["domain_identifier"] = domain
        prefix = f"{domain}_{prefix}"
    else:
        config["dataset_reader"].pop("domain_identifier", None)

    dataset_reader = DatasetReader.from_params(config['dataset_reader'])
    evaluation_data_path = data if data else config['validation_data_path']

    archive = load_archive(os.path.join(serialization_directory, "model.tar.gz"), cuda_device=device)
    model = archive.model
    model.eval()

    prediction_file_path = os.path.join(serialization_directory, prefix + "_predictions.txt")
    gold_file_path = os.path.join(serialization_directory, prefix + "_gold.txt")
    prediction_file = open(prediction_file_path, "w+")
    gold_file = open(gold_file_path, "w+")

    # Load the evaluation data and index it.
    print("reading evaluation data from {}".format(evaluation_data_path))
    instances = dataset_reader.read(evaluation_data_path)

    with torch.autograd.no_grad():
        iterator = BasicIterator(batch_size=32)
        iterator.index_with(model.vocab)

        model_predictions = []
        batches = iterator(instances, num_epochs=1, shuffle=False, cuda_device=device)
        for batch in Tqdm.tqdm(batches):
            result = model(**batch)
            predictions = model.decode(result)
            model_predictions.extend(predictions["tags"])

        for instance, prediction in zip(instances, model_predictions):
            fields = instance.fields
            try:
                # Most sentences have a verbal predicate, but not all.
                verb_index = fields["verb_indicator"].labels.index(1)
            except ValueError:
                verb_index = None

            gold_tags = fields["tags"].labels
            sentence = [x.text for x in fields["tokens"].tokens]

            write_to_conll_eval_file(prediction_file, gold_file,
                                     verb_index, sentence, prediction, gold_tags)
        prediction_file.close()
        gold_file.close()
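
# A hedged invocation sketch for main(); the directory and prefix are
# placeholders. An empty data argument falls back to the validation_data_path
# stored in the archived config.
if __name__ == "__main__":
    main(serialization_directory="serialization_dir",
         device=-1,
         data="",
         prefix="dev")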
Example No. 55
0
class TestTrainer(AllenNlpTestCase):
    def setUp(self):
        super(TestTrainer, self).setUp()
        self.instances = SequenceTaggingDatasetReader().read(self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
        vocab = Vocabulary.from_instances(self.instances)
        self.vocab = vocab
        self.model_params = Params({
                "text_field_embedder": {
                        "tokens": {
                                "type": "embedding",
                                "embedding_dim": 5
                                }
                        },
                "encoder": {
                        "type": "lstm",
                        "input_size": 5,
                        "hidden_size": 7,
                        "num_layers": 2
                        }
                })
        self.model = SimpleTagger.from_params(vocab=self.vocab, params=self.model_params)
        self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
        self.iterator = BasicIterator(batch_size=2)
        self.iterator.index_with(vocab)

    def test_trainer_can_run(self):
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_validation_accuracy' in metrics
        assert isinstance(metrics['best_validation_accuracy'], float)
        assert 'best_validation_accuracy3' in metrics
        assert isinstance(metrics['best_validation_accuracy3'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)

        # Making sure that both increasing and decreasing validation metrics work.
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          validation_metric='+loss',
                          num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_validation_accuracy' in metrics
        assert isinstance(metrics['best_validation_accuracy'], float)
        assert 'best_validation_accuracy3' in metrics
        assert isinstance(metrics['best_validation_accuracy3'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_trainer_can_run_cuda(self):
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=2,
                          cuda_device=0)
        trainer.train()

    @pytest.mark.skipif(torch.cuda.device_count() < 2,
                        reason="Need multiple GPUs.")
    def test_trainer_can_run_multiple_gpu(self):

        class MetaDataCheckWrapper(Model):
            """
            Checks that the metadata field has been correctly split across the batch dimension
            when running on multiple gpus.
            """
            def __init__(self, model):
                super().__init__(model.vocab)
                self.model = model

            def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint: disable=arguments-differ
                assert 'metadata' in kwargs and 'tags' in kwargs, \
                    f'metadata and tags must be provided. Got {kwargs.keys()} instead.'
                batch_size = kwargs['tokens']['tokens'].size()[0]
                assert len(kwargs['metadata']) == batch_size, \
                    f'metadata must be split appropriately. Expected {batch_size} elements, ' \
                    f"got {len(kwargs['metadata'])} elements."
                return self.model.forward(**kwargs)

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = Trainer(MetaDataCheckWrapper(self.model), self.optimizer,
                          multigpu_iterator, self.instances, num_epochs=2,
                          cuda_device=[0, 1])
        trainer.train()

    def test_trainer_can_resume_training(self):
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances,
                          validation_dataset=self.instances,
                          num_epochs=1, serialization_dir=self.TEST_DIR)
        trainer.train()
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR)

        epoch, val_metrics_per_epoch = new_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 1
        assert len(val_metrics_per_epoch) == 1
        assert isinstance(val_metrics_per_epoch[0], float)
        assert val_metrics_per_epoch[0] != 0.
        new_trainer.train()

    def test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_increasing_metric(
            self):
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR,
                              patience=5, validation_metric="+test")
        # when it is the only metric it should be considered the best
        assert new_trainer._is_best_so_far(1, [])  # pylint: disable=protected-access
        # when it is the same as one before it, it is not considered the best
        assert not new_trainer._is_best_so_far(.3, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is the best, it is considered the best
        assert new_trainer._is_best_so_far(13.00, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is not the best, it is not considered the best
        assert not new_trainer._is_best_so_far(.0013, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access

    def test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_decreasing_metric(self):
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR,
                              patience=5, validation_metric="-test")
        # when it is the only metric it should be considered the best
        assert new_trainer._is_best_so_far(1, [])  # pylint: disable=protected-access
        # when it is the same as one before it, it is not considered the best
        assert not new_trainer._is_best_so_far(.3, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is the best, it is considered the best
        assert new_trainer._is_best_so_far(.013, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is not the best, it is not considered the best
        assert not new_trainer._is_best_so_far(13.00, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
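
    # For orientation, a minimal sketch (an assumption, not the actual Trainer
    # internals) of the strictly-better comparison the two tests above exercise:
    #
    #     def is_best_so_far(metric, history, maximize=True):
    #         if not history:
    #             return True
    #         best = max(history) if maximize else min(history)
    #         return metric > best if maximize else metric < best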

    def test_should_stop_early_with_increasing_metric(self):
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR,
                              patience=5, validation_metric="+test")
        assert new_trainer._should_stop_early([.5, .3, .2, .1, .4, .4])  # pylint: disable=protected-access
        assert not new_trainer._should_stop_early([.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access

    def test_should_stop_early_with_flat_lining_metric(self):
        flatline = [.2] * 6
        assert Trainer(self.model, self.optimizer,  # pylint: disable=protected-access
                       self.iterator, self.instances,
                       validation_dataset=self.instances,
                       num_epochs=3,
                       serialization_dir=self.TEST_DIR,
                       patience=5,
                       validation_metric="+test")._should_stop_early(flatline)  # pylint: disable=protected-access
        assert Trainer(self.model, self.optimizer,  # pylint: disable=protected-access
                       self.iterator, self.instances,
                       validation_dataset=self.instances,
                       num_epochs=3,
                       serialization_dir=self.TEST_DIR,
                       patience=5,
                       validation_metric="-test")._should_stop_early(flatline)  # pylint: disable=protected-access

    def test_should_stop_early_with_decreasing_metric(self):
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR,
                              patience=5, validation_metric="-test")
        assert new_trainer._should_stop_early([.02, .3, .2, .1, .4, .4])  # pylint: disable=protected-access
        assert not new_trainer._should_stop_early([.3, .3, .2, .1, .4, .5])  # pylint: disable=protected-access
        assert new_trainer._should_stop_early([.1, .3, .2, .1, .4, .5])  # pylint: disable=protected-access

    def test_should_stop_early_with_early_stopping_disabled(self):
        # Increasing metric
        trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances,
                          validation_dataset=self.instances, num_epochs=100,
                          patience=None, validation_metric="+test")
        decreasing_history = [float(i) for i in reversed(range(20))]
        assert not trainer._should_stop_early(decreasing_history)  # pylint: disable=protected-access

        # Decreasing metric
        trainer = Trainer(self.model, self.optimizer, self.iterator, self.instances,
                          validation_dataset=self.instances, num_epochs=100,
                          patience=None, validation_metric="-test")
        increasing_history = [float(i) for i in range(20)]
        assert not trainer._should_stop_early(increasing_history)  # pylint: disable=protected-access
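
    # A sketch consistent with the assertions in the early-stopping tests above
    # (an assumption; the real Trainer._should_stop_early may differ in detail):
    #
    #     def should_stop_early(history, patience=5, maximize=True):
    #         if patience is None or len(history) <= patience:
    #             return False
    #         recent = history[-patience:]
    #         previous = history[:-patience]
    #         best_recent = max(recent) if maximize else min(recent)
    #         best_previous = max(previous) if maximize else min(previous)
    #         # stop when nothing in the patience window strictly improved
    #         return best_recent <= best_previous if maximize else best_recent >= best_previous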

    def test_should_stop_early_with_invalid_patience(self):
        for patience in [0, -1, -2, 1.5, 'None']:
            with pytest.raises(ConfigurationError,
                               message='No ConfigurationError for patience={}'.format(patience)):
                Trainer(self.model, self.optimizer, self.iterator, self.instances,
                        validation_dataset=self.instances, num_epochs=100,
                        patience=patience, validation_metric="+test")

    def test_trainer_can_run_with_lr_scheduler(self):
        lr_params = Params({"type": "reduce_on_plateau"})
        lr_scheduler = LearningRateScheduler.from_params(self.optimizer, lr_params)
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          learning_rate_scheduler=lr_scheduler,
                          validation_metric="-loss",
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2)
        trainer.train()

    def test_trainer_can_resume_with_lr_scheduler(self):
        # pylint: disable=protected-access
        lr_scheduler = LearningRateScheduler.from_params(
                self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          learning_rate_scheduler=lr_scheduler,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2, serialization_dir=self.TEST_DIR)
        trainer.train()

        new_lr_scheduler = LearningRateScheduler.from_params(
                self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
        new_trainer = Trainer(model=self.model,
                              optimizer=self.optimizer,
                              iterator=self.iterator,
                              learning_rate_scheduler=new_lr_scheduler,
                              train_dataset=self.instances,
                              validation_dataset=self.instances,
                              num_epochs=4, serialization_dir=self.TEST_DIR)
        epoch, _ = new_trainer._restore_checkpoint()
        assert epoch == 2
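        # The exponential scheduler counts completed epochs from zero, so after
        # two training epochs its last_epoch is expected to be 1.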
        assert new_trainer._learning_rate_scheduler.lr_scheduler.last_epoch == 1
        new_trainer.train()

    def test_trainer_raises_on_model_with_no_loss_key(self):
        class FakeModel(torch.nn.Module):
            def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
                return {}
        with pytest.raises(RuntimeError):
            trainer = Trainer(FakeModel(), self.optimizer,
                              self.iterator, self.instances,
                              num_epochs=2, serialization_dir=self.TEST_DIR)
            trainer.train()

    def test_trainer_can_log_histograms(self):
        # enable activation logging
        for module in self.model.modules():
            module.should_log_activations = True

        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=3,
                          serialization_dir=self.TEST_DIR,
                          histogram_interval=2)
        trainer.train()

    def test_trainer_respects_num_serialized_models_to_keep(self):
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=5,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=3)
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
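            # Note: the pattern below only matches single-digit epoch numbers,
            # which is sufficient for these short runs.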
            epochs = [int(re.search(r"_([0-9])\.th", fname).group(1))
                      for fname in file_names]
            assert sorted(epochs) == [2, 3, 4]

    def test_trainer_saves_metrics_every_epoch(self):
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=5,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=3)
        trainer.train()

        for epoch in range(5):
            epoch_file = self.TEST_DIR / f'metrics_epoch_{epoch}.json'
            assert epoch_file.exists()
            metrics = json.load(open(epoch_file))
            assert "validation_loss" in metrics
            assert "best_validation_loss" in metrics
            assert metrics.get("epoch") == epoch

    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        # To test:
        #   Create an iterator that sleeps for 2.5 seconds per epoch, so the total training
        #       time for one epoch is slightly greater than 2.5 seconds.
        #   Run for 6 epochs, keeping the last 2 models and also keeping a model every 5 seconds.
        #   Check the resulting checkpoints.  Should then have models at epochs
        #       2 and 4, plus the last two at 5 and 6.
        class WaitingIterator(BasicIterator):
            # pylint: disable=arguments-differ
            def _create_batches(self, *args, **kwargs):
                time.sleep(2.5)
                return super(WaitingIterator, self)._create_batches(*args, **kwargs)

        iterator = WaitingIterator(batch_size=2)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=6,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=2,
                          keep_serialized_model_every_num_seconds=5)
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9])\.th", fname).group(1))
                      for fname in file_names]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]

    def test_trainer_can_log_learning_rates_tensorboard(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          should_log_learning_rate=True,
                          summary_interval=2)

        trainer.train()

    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
                  for fname in file_names]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == '1'
        assert '.' in epochs[0]
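        # The four names typically look like model_state_epoch_0.<timestamp>.th,
        # model_state_epoch_0.th, model_state_epoch_1.<timestamp>.th and
        # model_state_epoch_1.th (the exact timestamp format is an assumption).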

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the checkpoints from the end of epochs 1 and 2, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(os.path.join(self.TEST_DIR, 'model_state_epoch_{}.th'.format(k)))
            os.remove(os.path.join(self.TEST_DIR, 'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model, self.optimizer,
                                  self.iterator, self.instances, num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
Example No. 56
0
class TestTrainer(AllenNlpTestCase):
    def setUp(self):
        super(TestTrainer, self).setUp()
        self.instances = SequenceTaggingDatasetReader().read('tests/fixtures/data/sequence_tagging.tsv')
        vocab = Vocabulary.from_instances(self.instances)
        self.vocab = vocab
        self.model_params = Params({
                "text_field_embedder": {
                        "tokens": {
                                "type": "embedding",
                                "embedding_dim": 5
                                }
                        },
                "encoder": {
                        "type": "lstm",
                        "input_size": 5,
                        "hidden_size": 7,
                        "num_layers": 2
                        }
                })
        self.model = SimpleTagger.from_params(self.vocab, self.model_params)
        self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
        self.iterator = BasicIterator(batch_size=2)
        self.iterator.index_with(vocab)

    def test_trainer_can_run(self):
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)

        # Making sure that both increasing and decreasing validation metrics work.
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          validation_metric='+loss',
                          num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)

    @pytest.mark.skipif(not torch.cuda.is_available(), reason="No CUDA device registered.")
    def test_trainer_can_run_cuda(self):
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=2,
                          cuda_device=0)
        trainer.train()

    @pytest.mark.skipif(torch.cuda.device_count() < 2,
                        reason="Need multiple GPUs.")
    def test_trainer_can_run_multiple_gpu(self):
        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = Trainer(self.model, self.optimizer,
                          multigpu_iterator, self.instances, num_epochs=2,
                          cuda_device=[0, 1])
        trainer.train()

    def test_trainer_can_resume_training(self):
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances,
                          validation_dataset=self.instances,
                          num_epochs=1, serialization_dir=self.TEST_DIR)
        trainer.train()
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR)

        epoch, val_metrics_per_epoch = new_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 1
        assert len(val_metrics_per_epoch) == 1
        assert isinstance(val_metrics_per_epoch[0], float)
        assert val_metrics_per_epoch[0] != 0.
        new_trainer.train()

    def test_should_stop_early_with_increasing_metric(self):
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR,
                              patience=5, validation_metric="+test")
        assert new_trainer._should_stop_early([.5, .3, .2, .1, .4, .4])  # pylint: disable=protected-access
        assert not new_trainer._should_stop_early([.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access

    def test_should_stop_early_with_decreasing_metric(self):
        new_trainer = Trainer(self.model, self.optimizer,
                              self.iterator, self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3, serialization_dir=self.TEST_DIR,
                              patience=5, validation_metric="-test")
        assert new_trainer._should_stop_early([.02, .3, .2, .1, .4, .4])  # pylint: disable=protected-access
        assert not new_trainer._should_stop_early([.3, .3, .2, .1, .4, .5])  # pylint: disable=protected-access

    def test_train_driver_raises_on_model_with_no_loss_key(self):

        class FakeModel(torch.nn.Module):
            def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
                return {}
        with pytest.raises(ConfigurationError):
            trainer = Trainer(FakeModel(), self.optimizer,
                              self.iterator, self.instances,
                              num_epochs=2, serialization_dir=self.TEST_DIR)
            trainer.train()

    def test_trainer_can_log_histograms(self):
        # enable activation logging
        for module in self.model.modules():
            module.should_log_activations = True

        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=3,
                          serialization_dir=self.TEST_DIR,
                          histogram_interval=2)
        trainer.train()

    def test_trainer_respects_num_serialized_models_to_keep(self):
        trainer = Trainer(self.model, self.optimizer,
                          self.iterator, self.instances, num_epochs=5,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=3)
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9])\.th", fname).group(1))
                      for fname in file_names]
            assert sorted(epochs) == [2, 3, 4]

    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        # To test:
        #   Create an iterator that sleeps for 0.5 seconds per epoch, so the total training
        #       time for one epoch is slightly greater than 0.5 seconds.
        #   Run for 6 epochs, keeping the last 2 models and also keeping a model every 1 second.
        #   Check the resulting checkpoints.  Should then have models at epochs
        #       2 and 4, plus the last two at 5 and 6.
        class WaitingIterator(BasicIterator):
            # pylint: disable=arguments-differ
            def _create_batches(self, *args, **kwargs):
                time.sleep(0.5)
                return super(WaitingIterator, self)._create_batches(*args, **kwargs)

        iterator = WaitingIterator(batch_size=2)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=6,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=2,
                          keep_serialized_model_every_num_seconds=1)
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9])\.th", fname).group(1))
                      for fname in file_names]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]

    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model, self.optimizer,
                          iterator, self.instances, num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
                  for fname in file_names]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == '1'
        assert '.' in epochs[0]

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the checkpoints from the end of epochs 1 and 2, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(os.path.join(self.TEST_DIR, 'model_state_epoch_{}.th'.format(k)))
            os.remove(os.path.join(self.TEST_DIR, 'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model, self.optimizer,
                                  self.iterator, self.instances, num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access