Code example #1
 def test_few_instances_per_epoch(self):
     # pylint: disable=protected-access
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2, instances_per_epoch=3)
         # First epoch: 3 instances -> [2, 1]
         batches = list(
             iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[
             self.instances[0], self.instances[1]
         ], [self.instances[2]]]
         # Second epoch: 3 instances -> [2, 1]
         batches = list(
             iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[
             self.instances[3], self.instances[4]
         ], [self.instances[0]]]
         # Third epoch: 3 instances -> [2, 1]
         batches = list(
             iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[
             self.instances[1], self.instances[2]
         ], [self.instances[3]]]
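A quick way to see why the second epoch starts at self.instances[3]: when instances_per_epoch is set, the iterator keeps a cursor into the (effectively cycled) dataset instead of restarting from the front each epoch. A minimal, library-free sketch of that wrap-around (plain Python, not AllenNLP code):

    import itertools

    stream = itertools.cycle(range(5))  # stand-in for the five test instances
    epochs = [list(itertools.islice(stream, 3)) for _ in range(3)]
    # Each epoch takes the next 3 items and wraps around the end of the data.
    assert epochs == [[0, 1, 2], [3, 4, 0], [1, 2, 3]]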
Code example #2
File: trainer_test.py Project: sanyu12/Bert_Attempt
 def setUp(self):
     super(TestTrainer, self).setUp()
     self.instances = SequenceTaggingDatasetReader().read(
         self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
     vocab = Vocabulary.from_instances(self.instances)
     self.vocab = vocab
     self.model_params = Params({
         "text_field_embedder": {
             "tokens": {
                 "type": "embedding",
                 "embedding_dim": 5
             }
         },
         "encoder": {
             "type": "lstm",
             "input_size": 5,
             "hidden_size": 7,
             "num_layers": 2
         }
     })
     self.model = SimpleTagger.from_params(vocab=self.vocab,
                                           params=self.model_params)
     self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
     self.iterator = BasicIterator(batch_size=2)
     self.iterator.index_with(vocab)
Code example #3
 def test_can_optimise_model_with_dense_and_sparse_params(self):
     optimizer_params = Params({"type": "dense_sparse_adam"})
     parameters = [[n, p] for n, p in self.model.named_parameters()
                   if p.requires_grad]
     optimizer = Optimizer.from_params(parameters, optimizer_params)
     iterator = BasicIterator(2)
     iterator.index_with(self.vocab)
     Trainer(self.model, optimizer, iterator, self.instances).train()
Code example #4
    def test_from_params(self):
        # pylint: disable=protected-access
        params = Params({})
        iterator = BasicIterator.from_params(params)
        assert iterator._batch_size == 32  # default value

        params = Params({"batch_size": 10})
        iterator = BasicIterator.from_params(params)
        assert iterator._batch_size == 10
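For readers new to the Params machinery: as far as this test is concerned, from_params is keyword-argument construction driven by a JSON-like dict, with missing keys falling back to the constructor defaults. A small sketch, assuming the AllenNLP 0.x import paths used throughout these examples:

    from allennlp.common import Params
    from allennlp.data.iterators import BasicIterator

    # An empty Params behaves like calling the constructor with no arguments...
    assert BasicIterator.from_params(Params({}))._batch_size == BasicIterator()._batch_size
    # ...and an explicit key behaves like passing the argument directly.
    assert (BasicIterator.from_params(Params({"batch_size": 10}))._batch_size
            == BasicIterator(batch_size=10)._batch_size)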
Code example #5
 def test_create_batches_groups_correctly(self):
     # pylint: disable=protected-access
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         batches = list(
             iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[
             self.instances[0], self.instances[1]
         ], [self.instances[2], self.instances[3]], [self.instances[4]]]
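The grouping asserted here is plain fixed-size chunking; a library-free sketch of the same arithmetic:

    def chunk(items, size):
        """Split items into consecutive groups of at most `size`."""
        return [items[i:i + size] for i in range(0, len(items), size)]

    # Five instances with batch_size=2 come out as groups of sizes [2, 2, 1].
    assert [len(group) for group in chunk(list(range(5)), 2)] == [2, 2, 1]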
Code example #6
File: trainer_test.py Project: sanyu12/Bert_Attempt
 def test_trainer_can_run_multiple_gpu(self):
     multigpu_iterator = BasicIterator(batch_size=4)
     multigpu_iterator.index_with(self.vocab)
     trainer = Trainer(self.model,
                       self.optimizer,
                       multigpu_iterator,
                       self.instances,
                       num_epochs=2,
                       cuda_device=[0, 1])
     trainer.train()
Code example #7
 def test_max_instances_in_memory(self):
     # pylint: disable=protected-access
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2, max_instances_in_memory=3)
         # One epoch: 5 instances -> [2, 1, 2]
         batches = list(
             iterator._create_batches(test_instances, shuffle=False))
         grouped_instances = [batch.instances for batch in batches]
         assert grouped_instances == [[
             self.instances[0], self.instances[1]
         ], [self.instances[2]], [self.instances[3], self.instances[4]]]
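The batch sizes [2, 1, 2] look odd until you notice that max_instances_in_memory=3 makes the iterator pull instances into memory three at a time and batch each block independently, so the first block of three splits into [2, 1]. A library-free sketch of that interaction:

    def chunk(items, size):
        return [items[i:i + size] for i in range(0, len(items), size)]

    blocks = chunk(list(range(5)), 3)  # memory blocks: [0, 1, 2] and [3, 4]
    batches = [group for block in blocks for group in chunk(block, 2)]
    assert [len(group) for group in batches] == [2, 1, 2]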
Code example #8
File: trainer_test.py Project: sanyu12/Bert_Attempt
    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model,
                          self.optimizer,
                          iterator,
                          self.instances,
                          num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [
            re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
            for fname in file_names
        ]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == '1'
        assert '.' in epochs[0]

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the end-of-epoch checkpoints for epochs 1 and 2, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(
                os.path.join(self.TEST_DIR,
                             'model_state_epoch_{}.th'.format(k)))
            os.remove(
                os.path.join(self.TEST_DIR,
                             'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model,
                                  self.optimizer,
                                  self.iterator,
                                  self.instances,
                                  num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access
Code example #9
 def test_get_num_batches(self):
     # Lazy and instances per epoch not specified.
     assert BasicIterator(batch_size=2).get_num_batches(
         self.lazy_instances) == 1
     # Lazy and instances per epoch specified.
     assert BasicIterator(batch_size=2,
                          instances_per_epoch=21).get_num_batches(
                              self.lazy_instances) == 11
     # Not lazy and instances per epoch specified.
     assert BasicIterator(batch_size=2,
                          instances_per_epoch=21).get_num_batches(
                              self.instances) == 11
     # Not lazy and instances per epoch not specified.
     assert BasicIterator(batch_size=2).get_num_batches(self.instances) == 3
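The arithmetic behind these assertions: when the instance count is known (an eager dataset, or instances_per_epoch given), the batch count is a ceiling division; for a lazy dataset with no instances_per_epoch the length is unknown and, as the first assertion shows, the method falls back to 1. A quick check of the ceiling divisions:

    import math

    assert math.ceil(21 / 2) == 11  # instances_per_epoch=21, batch_size=2
    assert math.ceil(5 / 2) == 3    # the five eager instances, batch_size=2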
Code example #10
    def test_elmo_bilm(self):
        # get the raw data
        sentences, expected_lm_embeddings = self._load_sentences_embeddings()

        # load the test model
        elmo_bilm = _ElmoBiLm(self.options_file, self.weight_file)

        # Deal with the data.
        indexer = ELMoTokenCharactersIndexer()

        # For each sentence, first create a TextField, then create an instance
        instances = []
        for batch in zip(*sentences):
            for sentence in batch:
                tokens = [Token(token) for token in sentence.split()]
                field = TextField(tokens, {'character_ids': indexer})
                instance = Instance({"elmo": field})
                instances.append(instance)

        vocab = Vocabulary()

        # Now finally we can iterate through batches.
        iterator = BasicIterator(3)
        iterator.index_with(vocab)
        for i, batch in enumerate(
                iterator(instances, num_epochs=1, shuffle=False)):
            lm_embeddings = elmo_bilm(batch['elmo']['character_ids'])
            top_layer_embeddings, mask = remove_sentence_boundaries(
                lm_embeddings['activations'][2], lm_embeddings['mask'])

            # check the mask lengths
            lengths = mask.data.numpy().sum(axis=1)
            batch_sentences = [sentences[k][i] for k in range(3)]
            expected_lengths = [
                len(sentence.split()) for sentence in batch_sentences
            ]
            self.assertEqual(lengths.tolist(), expected_lengths)

            # get the expected embeddings and compare!
            expected_top_layer = [
                expected_lm_embeddings[k][i] for k in range(3)
            ]
            for k in range(3):
                self.assertTrue(
                    numpy.allclose(
                        top_layer_embeddings[k, :lengths[k], :].data.numpy(),
                        expected_top_layer[k],
                        atol=1.0e-6))
Code example #11
    def test_maximum_samples_per_batch(self):
        for test_instances in (self.instances, self.lazy_instances):
            # pylint: disable=protected-access
            iterator = BasicIterator(
                batch_size=3, maximum_samples_per_batch=['num_tokens', 9])
            batches = list(
                iterator._create_batches(test_instances, shuffle=False))

            # ensure all instances are in a batch
            grouped_instances = [batch.instances for batch in batches]
            num_instances = sum(len(group) for group in grouped_instances)
            assert num_instances == len(self.instances)

            # ensure all batches are sufficiently small
            for batch in batches:
                batch_sequence_length = max([
                    instance.get_padding_lengths()['text']['num_tokens']
                    for instance in batch.instances
                ])
                assert batch_sequence_length * len(batch.instances) <= 9
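The invariant the loop checks is that the padded size of a batch, its longest sequence times its instance count, never exceeds the cap of 9 samples. A standalone sketch of that predicate, using hypothetical sequence lengths:

    def batch_is_small_enough(sequence_lengths, cap=9):
        # Every instance is padded to the longest length in its batch.
        return max(sequence_lengths) * len(sequence_lengths) <= cap

    assert batch_is_small_enough([3, 3, 3])      # 3 * 3 = 9, exactly at the cap
    assert not batch_is_small_enough([4, 4, 4])  # 4 * 3 = 12, over the cap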
Code example #12
 def test_yield_one_epoch_iterates_over_the_data_once(self):
     for test_instances in (self.instances, self.lazy_instances):
         iterator = BasicIterator(batch_size=2)
         batches = list(iterator(test_instances, num_epochs=1))
         # We just want to get the single-token array for the text field in the instance.
         instances = [
             tuple(instance.detach().cpu().numpy()) for batch in batches
             for instance in batch['text']["tokens"]
         ]
         assert len(instances) == 5
         self.assert_instances_are_correct(instances)
Code example #13
 def test_call_iterates_over_data_forever(self):
     for test_instances in (self.instances, self.lazy_instances):
         generator = BasicIterator(batch_size=2)(test_instances)
         batches = [next(generator)
                    for _ in range(18)]  # going over the data 6 times
         # We just want to get the single-token array for the text field in the instance.
         instances = [
             tuple(instance.detach().cpu().numpy()) for batch in batches
             for instance in batch['text']["tokens"]
         ]
         assert len(instances) == 5 * 6
         self.assert_instances_are_correct(instances)
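A back-of-the-envelope check of the numbers above: one pass over 5 instances at batch_size=2 yields ceil(5/2) = 3 batches, so 18 batches is exactly 6 full passes, i.e. 5 * 6 = 30 instance occurrences:

    import math

    batches_per_pass = math.ceil(5 / 2)
    assert batches_per_pass == 3
    assert 18 == 6 * batches_per_pass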
Code example #14
    def test_shuffle(self):
        # pylint: disable=protected-access
        for test_instances in (self.instances, self.lazy_instances):

            iterator = BasicIterator(batch_size=2, instances_per_epoch=100)

            in_order_batches = list(
                iterator._create_batches(test_instances, shuffle=False))
            shuffled_batches = list(
                iterator._create_batches(test_instances, shuffle=True))

            assert len(in_order_batches) == len(shuffled_batches)

            # With 100 instances, shuffling had better change the order.
            assert in_order_batches != shuffled_batches

            # But not the counts of the instances.
            in_order_counts = Counter(instance for batch in in_order_batches
                                      for instance in batch)
            shuffled_counts = Counter(instance for batch in shuffled_batches
                                      for instance in batch)
            assert in_order_counts == shuffled_counts
Code example #15
    def test_multiple_cursors(self):
        # pylint: disable=protected-access
        lazy_instances1 = _LazyInstances(lambda: (i for i in self.instances))
        lazy_instances2 = _LazyInstances(lambda: (i for i in self.instances))

        eager_instances1 = self.instances[:]
        eager_instances2 = self.instances[:]

        for instances1, instances2 in [(eager_instances1, eager_instances2),
                                       (lazy_instances1, lazy_instances2)]:
            iterator = BasicIterator(batch_size=1, instances_per_epoch=2)
            iterator.index_with(self.vocab)

            # First epoch through dataset1
            batches = list(iterator._create_batches(instances1, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[0]],
                                         [self.instances[1]]]

            # First epoch through dataset2
            batches = list(iterator._create_batches(instances2, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[0]],
                                         [self.instances[1]]]

            # Second epoch through dataset1
            batches = list(iterator._create_batches(instances1, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[2]],
                                         [self.instances[3]]]

            # Second epoch through dataset2
            batches = list(iterator._create_batches(instances2, shuffle=False))
            grouped_instances = [batch.instances for batch in batches]
            assert grouped_instances == [[self.instances[2]],
                                         [self.instances[3]]]
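What this test pins down is that the iterator keeps an independent cursor per dataset, so two datasets fed through the same iterator do not disturb each other's position. A minimal sketch of that bookkeeping (not the AllenNLP implementation; the cursors dict keyed by dataset stands in for its internal state):

    import itertools

    cursors = {}

    def next_epoch(key, items, epoch_size):
        """Take the next epoch_size items for this dataset, resuming its cursor."""
        if key not in cursors:
            cursors[key] = itertools.cycle(items)
        return list(itertools.islice(cursors[key], epoch_size))

    assert next_epoch('d1', range(5), 2) == [0, 1]
    assert next_epoch('d2', range(5), 2) == [0, 1]  # d2 starts from the beginning
    assert next_epoch('d1', range(5), 2) == [2, 3]  # d1 resumes where it left off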
Code example #16
    def test_regularization(self):
        penalty = self.model.get_regularization_penalty()
        assert penalty == 0

        iterator = BasicIterator(batch_size=32)
        trainer = Trainer(
            self.model,
            None,  # optimizer,
            iterator,
            self.instances)

        # You get a RuntimeError if you call `model.forward` twice on the same inputs.
        # The data and config are such that the whole dataset is one batch.
        training_batch = next(iterator(self.instances, num_epochs=1))
        validation_batch = next(iterator(self.instances, num_epochs=1))

        training_loss = trainer._batch_loss(training_batch,
                                            for_training=True).data
        validation_loss = trainer._batch_loss(validation_batch,
                                              for_training=False).data

        # Training loss should have the regularization penalty, but validation loss should not.
        # Since the penalty is zero here (asserted above), the two come out exactly equal.
        assert (training_loss == validation_loss).all()
Code example #17
File: trainer_test.py Project: sanyu12/Bert_Attempt
class TestTrainer(AllenNlpTestCase):
    def setUp(self):
        super(TestTrainer, self).setUp()
        self.instances = SequenceTaggingDatasetReader().read(
            self.FIXTURES_ROOT / 'data' / 'sequence_tagging.tsv')
        vocab = Vocabulary.from_instances(self.instances)
        self.vocab = vocab
        self.model_params = Params({
            "text_field_embedder": {
                "tokens": {
                    "type": "embedding",
                    "embedding_dim": 5
                }
            },
            "encoder": {
                "type": "lstm",
                "input_size": 5,
                "hidden_size": 7,
                "num_layers": 2
            }
        })
        self.model = SimpleTagger.from_params(vocab=self.vocab,
                                              params=self.model_params)
        self.optimizer = torch.optim.SGD(self.model.parameters(), 0.01)
        self.iterator = BasicIterator(batch_size=2)
        self.iterator.index_with(vocab)

    def test_trainer_can_run(self):
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_validation_accuracy' in metrics
        assert isinstance(metrics['best_validation_accuracy'], float)
        assert 'best_validation_accuracy3' in metrics
        assert isinstance(metrics['best_validation_accuracy3'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)

        # Making sure that both increasing and decreasing validation metrics work.
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          validation_metric='+loss',
                          num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_validation_accuracy' in metrics
        assert isinstance(metrics['best_validation_accuracy'], float)
        assert 'best_validation_accuracy3' in metrics
        assert isinstance(metrics['best_validation_accuracy3'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)

    @pytest.mark.skipif(not torch.cuda.is_available(),
                        reason="No CUDA device registered.")
    def test_trainer_can_run_cuda(self):
        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          num_epochs=2,
                          cuda_device=0)
        trainer.train()

    @pytest.mark.skipif(torch.cuda.device_count() < 2,
                        reason="Need multiple GPUs.")
    def test_trainer_can_run_multiple_gpu(self):
        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = Trainer(self.model,
                          self.optimizer,
                          multigpu_iterator,
                          self.instances,
                          num_epochs=2,
                          cuda_device=[0, 1])
        trainer.train()

    def test_trainer_can_resume_training(self):
        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          validation_dataset=self.instances,
                          num_epochs=1,
                          serialization_dir=self.TEST_DIR)
        trainer.train()
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR)

        epoch, val_metrics_per_epoch = new_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 1
        assert len(val_metrics_per_epoch) == 1
        assert isinstance(val_metrics_per_epoch[0], float)
        assert val_metrics_per_epoch[0] != 0.
        new_trainer.train()

    def test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_increasing_metric(
            self):
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              patience=5,
                              validation_metric="+test")
        # when it is the only metric it should be considered the best
        assert new_trainer._is_best_so_far(1, [])  # pylint: disable=protected-access
        # when it is the same as one before it, it is not considered the best
        assert not new_trainer._is_best_so_far(.3, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is the best it is considered the best
        assert new_trainer._is_best_so_far(13.00, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is not the best it is not considered the best
        assert not new_trainer._is_best_so_far(.0013, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access

    def test_metric_only_considered_best_so_far_when_strictly_better_than_those_before_it_decreasing_metric(
            self):
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              patience=5,
                              validation_metric="-test")
        # when it is the only metric it should be considered the best
        assert new_trainer._is_best_so_far(1, [])  # pylint: disable=protected-access
        # when it is the same as one before it, it is not considered the best
        assert not new_trainer._is_best_so_far(.3, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is the best it is considered the best
        assert new_trainer._is_best_so_far(.013, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access
        # when it is not the best it is not considered the best
        assert not new_trainer._is_best_so_far(13.00, [.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access

    def test_should_stop_early_with_increasing_metric(self):
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              patience=5,
                              validation_metric="+test")
        assert new_trainer._should_stop_early([.5, .3, .2, .1, .4, .4])  # pylint: disable=protected-access
        assert not new_trainer._should_stop_early([.3, .3, .3, .2, .5, .1])  # pylint: disable=protected-access

    def test_should_stop_early_with_flat_lining_metric(self):
        flatline = [.2] * 6
        assert Trainer(
            self.model,
            self.optimizer,  # pylint: disable=protected-access
            self.iterator,
            self.instances,
            validation_dataset=self.instances,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
            patience=5,
            validation_metric="+test")._should_stop_early(flatline)  # pylint: disable=protected-access
        assert Trainer(
            self.model,
            self.optimizer,  # pylint: disable=protected-access
            self.iterator,
            self.instances,
            validation_dataset=self.instances,
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
            patience=5,
            validation_metric="-test")._should_stop_early(flatline)  # pylint: disable=protected-access

    def test_should_stop_early_with_decreasing_metric(self):
        new_trainer = Trainer(self.model,
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              validation_dataset=self.instances,
                              num_epochs=3,
                              serialization_dir=self.TEST_DIR,
                              patience=5,
                              validation_metric="-test")
        assert new_trainer._should_stop_early([.02, .3, .2, .1, .4, .4])  # pylint: disable=protected-access
        assert not new_trainer._should_stop_early([.3, .3, .2, .1, .4, .5])  # pylint: disable=protected-access
        assert new_trainer._should_stop_early([.1, .3, .2, .1, .4, .5])  # pylint: disable=protected-access

    def test_should_stop_early_with_early_stopping_disabled(self):
        # Increasing metric
        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          validation_dataset=self.instances,
                          num_epochs=100,
                          patience=None,
                          validation_metric="+test")
        decreasing_history = [float(i) for i in reversed(range(20))]
        assert not trainer._should_stop_early(decreasing_history)  # pylint: disable=protected-access

        # Decreasing metric
        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          validation_dataset=self.instances,
                          num_epochs=100,
                          patience=None,
                          validation_metric="-test")
        increasing_history = [float(i) for i in range(20)]
        assert not trainer._should_stop_early(increasing_history)  # pylint: disable=protected-access

    def test_should_stop_early_with_invalid_patience(self):
        for patience in [0, -1, -2, 1.5, 'None']:
            with pytest.raises(
                    ConfigurationError,
                    message='No ConfigurationError for patience={}'.format(
                        patience)):
                Trainer(self.model,
                        self.optimizer,
                        self.iterator,
                        self.instances,
                        validation_dataset=self.instances,
                        num_epochs=100,
                        patience=patience,
                        validation_metric="+test")

    def test_trainer_can_run_with_lr_scheduler(self):

        lr_params = Params({"type": "reduce_on_plateau"})
        lr_scheduler = LearningRateScheduler.from_params(
            self.optimizer, lr_params)
        trainer = Trainer(model=self.model,
                          optimizer=self.optimizer,
                          iterator=self.iterator,
                          learning_rate_scheduler=lr_scheduler,
                          validation_metric="-loss",
                          train_dataset=self.instances,
                          validation_dataset=self.instances,
                          num_epochs=2)
        trainer.train()

    def test_trainer_raises_on_model_with_no_loss_key(self):
        class FakeModel(torch.nn.Module):
            def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
                return {}

        with pytest.raises(RuntimeError):
            trainer = Trainer(FakeModel(),
                              self.optimizer,
                              self.iterator,
                              self.instances,
                              num_epochs=2,
                              serialization_dir=self.TEST_DIR)
            trainer.train()

    def test_trainer_can_log_histograms(self):
        # enable activation logging
        for module in self.model.modules():
            module.should_log_activations = True

        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          num_epochs=3,
                          serialization_dir=self.TEST_DIR,
                          histogram_interval=2)
        trainer.train()

    def test_trainer_respects_num_serialized_models_to_keep(self):
        trainer = Trainer(self.model,
                          self.optimizer,
                          self.iterator,
                          self.instances,
                          num_epochs=5,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=3)
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [
                int(re.search(r"_([0-9])\.th", fname).group(1))
                for fname in file_names
            ]
            assert sorted(epochs) == [2, 3, 4]

    def test_trainer_respects_keep_serialized_model_every_num_seconds(self):
        # To test:
        #   Create an iterator that sleeps for 2.5 seconds per epoch, so the total training
        #       time for one epoch is slightly greater than 2.5 seconds.
        #   Run for 6 epochs, keeping the last 2 models, models also kept every 5 seconds.
        #   Check the resulting checkpoints.  Should then have models at epochs
        #       2, 4, plus the last two at 5 and 6.
        class WaitingIterator(BasicIterator):
            # pylint: disable=arguments-differ
            def _create_batches(self, *args, **kwargs):
                time.sleep(2.5)
                return super(WaitingIterator,
                             self)._create_batches(*args, **kwargs)

        iterator = WaitingIterator(batch_size=2)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model,
                          self.optimizer,
                          iterator,
                          self.instances,
                          num_epochs=6,
                          serialization_dir=self.TEST_DIR,
                          num_serialized_models_to_keep=2,
                          keep_serialized_model_every_num_seconds=5)
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [
                int(re.search(r"_([0-9])\.th", fname).group(1))
                for fname in file_names
            ]
            # epoch N has N-1 in file name
            assert sorted(epochs) == [1, 3, 4, 5]

    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = Trainer(self.model,
                          self.optimizer,
                          iterator,
                          self.instances,
                          num_epochs=2,
                          serialization_dir=self.TEST_DIR,
                          model_save_interval=0.0001)

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = 'model_state_epoch_*'
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [
            re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
            for fname in file_names
        ]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == '1'
        assert '.' in epochs[0]

        # Now make certain we can restore from timestamped checkpoint.
        # To do so, remove the end-of-epoch checkpoints for epochs 1 and 2, so
        # that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(
                os.path.join(self.TEST_DIR,
                             'model_state_epoch_{}.th'.format(k)))
            os.remove(
                os.path.join(self.TEST_DIR,
                             'training_state_epoch_{}.th'.format(k)))
        os.remove(os.path.join(self.TEST_DIR, 'best.th'))

        restore_trainer = Trainer(self.model,
                                  self.optimizer,
                                  self.iterator,
                                  self.instances,
                                  num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  model_save_interval=0.0001)
        epoch, _ = restore_trainer._restore_checkpoint()  # pylint: disable=protected-access
        assert epoch == 2
        # One batch per epoch.
        assert restore_trainer._batch_num_total == 2  # pylint: disable=protected-access