    def test_trainer_can_run(self):
        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  callbacks=self.default_callbacks(serialization_dir=None),
                                  num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_validation_accuracy' in metrics
        assert isinstance(metrics['best_validation_accuracy'], float)
        assert 'best_validation_accuracy3' in metrics
        assert isinstance(metrics['best_validation_accuracy3'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)
        assert 'peak_cpu_memory_MB' in metrics

        # Making sure that both increasing and decreasing validation metrics work.
        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  callbacks=self.default_callbacks(validation_metric="+loss",
                                                                   serialization_dir=None),
                                  num_epochs=2)
        metrics = trainer.train()
        assert 'best_validation_loss' in metrics
        assert isinstance(metrics['best_validation_loss'], float)
        assert 'best_validation_accuracy' in metrics
        assert isinstance(metrics['best_validation_accuracy'], float)
        assert 'best_validation_accuracy3' in metrics
        assert isinstance(metrics['best_validation_accuracy3'], float)
        assert 'best_epoch' in metrics
        assert isinstance(metrics['best_epoch'], int)
        assert 'peak_cpu_memory_MB' in metrics
        assert isinstance(metrics['peak_cpu_memory_MB'], float)
        assert metrics['peak_cpu_memory_MB'] > 0
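# The "+"/"-" prefix on validation_metric encodes direction: "+" means higher is
# better, "-" means lower is better (which is why "+loss" deliberately makes later
# epochs look worse in the resume tests below). A minimal runnable sketch of that
# convention, using a hypothetical helper name:
def parse_validation_metric(spec: str):
    # "+loss" -> (True, "loss"); "-loss" -> (False, "loss")
    should_increase = not spec.startswith("-")
    return should_increase, spec.lstrip("+-")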
    def test_trainer_can_run_cuda(self):
        self.model.cuda()
        trainer = CallbackTrainer(self.model, self.optimizer,
                                  num_epochs=2,
                                  callbacks=self.default_callbacks(),
                                  cuda_device=0)
        trainer.train()
    def test_trainer_can_resume_with_lr_scheduler(self):
        lr_scheduler = LearningRateScheduler.from_params(
                self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
        callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

        trainer = CallbackTrainer(model=self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  callbacks=callbacks,
                                  num_epochs=2, serialization_dir=self.TEST_DIR)
        trainer.train()

        new_lr_scheduler = LearningRateScheduler.from_params(
                self.optimizer, Params({"type": "exponential", "gamma": 0.5}))
        callbacks = self.default_callbacks() + [UpdateLearningRate(new_lr_scheduler)]

        new_trainer = CallbackTrainer(model=self.model,
                                      training_data=self.instances,
                                      iterator=self.iterator,
                                      optimizer=self.optimizer,
                                      callbacks=callbacks,
                                      num_epochs=4, serialization_dir=self.TEST_DIR)
        new_trainer.handler.fire_event(Events.TRAINING_START)
        assert new_trainer.epoch_number == 2
        assert new_lr_scheduler.lr_scheduler.last_epoch == 1
        new_trainer.train()
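    # Note on the assertions above: trainer.epoch_number is the next epoch to run
    # (2 after a completed 2-epoch run), while the wrapped scheduler's last_epoch
    # records the 0-indexed most recently completed epoch, hence 1. The momentum
    # scheduler resume test below shows the same relationship (4 vs 3).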
    def test_restored_training_returns_best_epoch_metrics_even_if_no_better_epoch_is_found_after_restoring(self):
        # Instead of -loss, use +loss to ensure the 2nd epoch is considered worse.
        # Run 1 epoch of original training.
        original_trainer = CallbackTrainer(self.model,
                                           training_data=self.instances,
                                           iterator=self.iterator,
                                           optimizer=self.optimizer,
                                           callbacks=self.default_callbacks(validation_metric="+loss"),
                                           num_epochs=1, serialization_dir=self.TEST_DIR)
        training_metrics = original_trainer.train()

        # Run 1 epoch of restored training.
        restored_trainer = CallbackTrainer(
                self.model,
                training_data=self.instances,
                iterator=self.iterator,
                optimizer=self.optimizer,
                callbacks=self.default_callbacks(validation_metric="+loss"),
                num_epochs=2, serialization_dir=self.TEST_DIR)
        restored_metrics = restored_trainer.train()

        assert "best_validation_loss" in restored_metrics
        assert "best_validation_accuracy" in restored_metrics
        assert "best_validation_accuracy3" in restored_metrics
        assert "best_epoch" in restored_metrics

        # The epoch 2 validation loss should be less than that of epoch 1.
        assert training_metrics["best_validation_loss"] == restored_metrics["best_validation_loss"]
        assert training_metrics["best_epoch"] == 0
        assert training_metrics["validation_loss"] > restored_metrics["validation_loss"]
    def test_trainer_can_resume_training_for_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(self.model.named_parameters())
        callbacks = self.default_callbacks() + [UpdateMovingAverage(moving_average)]

        trainer = CallbackTrainer(self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=1, serialization_dir=self.TEST_DIR,
                                  callbacks=callbacks)
        trainer.train()

        new_moving_average = ExponentialMovingAverage(self.model.named_parameters())
        new_callbacks = self.default_callbacks() + [UpdateMovingAverage(new_moving_average)]

        new_trainer = CallbackTrainer(self.model,
                                      training_data=self.instances,
                                      iterator=self.iterator,
                                      optimizer=self.optimizer,
                                      num_epochs=3, serialization_dir=self.TEST_DIR,
                                      callbacks=new_callbacks)

        new_trainer.handler.fire_event(Events.TRAINING_START)
        assert new_trainer.epoch_number == 1

        tracker = new_trainer.metric_tracker
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None  # pylint: disable=protected-access

        new_trainer.train()
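# For reference, an exponential moving average like the one UpdateMovingAverage
# maintains follows the standard shadow-parameter update; this is a sketch of the
# arithmetic, not the library's implementation:
def ema_update(shadow, param, decay=0.9999):
    # Each step pulls the shadow value a fraction (1 - decay) toward the
    # current parameter value.
    return decay * shadow + (1.0 - decay) * param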
    def test_trainer_can_run_and_resume_with_momentum_scheduler(self):
        scheduler = MomentumScheduler.from_params(
                self.optimizer, Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}))
        callbacks = self.default_callbacks() + [UpdateMomentum(scheduler)]
        trainer = CallbackTrainer(model=self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=4,
                                  callbacks=callbacks,
                                  serialization_dir=self.TEST_DIR)
        trainer.train()

        new_scheduler = MomentumScheduler.from_params(
                self.optimizer, Params({"type": "inverted_triangular", "cool_down": 2, "warm_up": 2}))
        new_callbacks = self.default_callbacks() + [UpdateMomentum(new_scheduler)]
        new_trainer = CallbackTrainer(model=self.model,
                                      training_data=self.instances,
                                      iterator=self.iterator,
                                      optimizer=self.optimizer,
                                      num_epochs=6,
                                      callbacks=new_callbacks,
                                      serialization_dir=self.TEST_DIR)
        new_trainer.handler.fire_event(Events.TRAINING_START)
        assert new_trainer.epoch_number == 4
        assert new_scheduler.last_epoch == 3
        new_trainer.train()
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = "allennlp/tests/fixtures/data/wikitables/"
        offline_lf_directory = wikitables_dir + "action_space_walker_output/"
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
        instances = wikitables_reader.read(wikitables_dir +
                                           "sample_data.examples")
        archive_path = (self.FIXTURES_ROOT / "semantic_parsing" /
                        "wikitables" / "serialization" / "model.tar.gz")
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)

        trainer = CallbackTrainer(
            model,
            instances,
            multigpu_iterator,
            self.optimizer,
            num_epochs=2,
            cuda_device=[0, 1],
            callbacks=[GradientNormAndClip()],
        )
        trainer.train()
    def test_trainer_saves_and_loads_best_validation_metrics_correctly_2(self):
        # Use +loss and run 1 epoch of original training and one of restored training.
        # Run 1 epoch of original training.
        trainer = CallbackTrainer(
                self.model, self.optimizer,
                callbacks=self.default_callbacks(validation_metric="+loss"),
                num_epochs=1, serialization_dir=self.TEST_DIR)
        trainer.train()

        _ = trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)
        best_epoch_1 = trainer.metric_tracker.best_epoch
        best_validation_metrics_epoch_1 = trainer.metric_tracker.best_epoch_metrics
        # best_validation_metrics_epoch_1: {'accuracy': 0.75, 'accuracy3': 1.0, 'loss': 0.6243013441562653}
        assert isinstance(best_validation_metrics_epoch_1, dict)
        assert "loss" in best_validation_metrics_epoch_1

        # Run 1 more epoch of restored training.
        restore_trainer = CallbackTrainer(
                self.model, self.optimizer,
                callbacks=self.default_callbacks(validation_metric="+loss"),
                num_epochs=2, serialization_dir=self.TEST_DIR)
        restore_trainer.train()
        _ = restore_trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)
        best_epoch_2 = restore_trainer.metric_tracker.best_epoch
        best_validation_metrics_epoch_2 = restore_trainer.metric_tracker.best_epoch_metrics

        # Because we used +loss, the 2nd epoch won't be better than the 1st, so the
        # best validation metrics should be the same.
        assert best_epoch_1 == best_epoch_2 == 0
        assert best_validation_metrics_epoch_2 == best_validation_metrics_epoch_1
    def test_trainer_can_resume_training(self):
        trainer = CallbackTrainer(self.model,
                                  self.optimizer,
                                  callbacks=self.default_callbacks(),
                                  num_epochs=1,
                                  serialization_dir=self.TEST_DIR)
        trainer.train()

        new_trainer = CallbackTrainer(self.model,
                                      self.optimizer,
                                      callbacks=self.default_callbacks(),
                                      num_epochs=3,
                                      serialization_dir=self.TEST_DIR)

        new_trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)

        assert new_trainer.epoch_number == 1

        tracker = new_trainer.metric_tracker

        assert tracker is not None
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None

        new_trainer.train()
    def test_trainer_can_resume_with_lr_scheduler(self):
        lr_scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({
                "type": "exponential",
                "gamma": 0.5
            }))
        callbacks = self.default_callbacks() + [LrsCallback(lr_scheduler)]

        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  callbacks=callbacks,
                                  num_epochs=2,
                                  serialization_dir=self.TEST_DIR)
        trainer.train()

        new_lr_scheduler = LearningRateScheduler.from_params(
            self.optimizer, Params({
                "type": "exponential",
                "gamma": 0.5
            }))
        callbacks = self.default_callbacks() + [LrsCallback(new_lr_scheduler)]

        new_trainer = CallbackTrainer(model=self.model,
                                      optimizer=self.optimizer,
                                      callbacks=callbacks,
                                      num_epochs=4,
                                      serialization_dir=self.TEST_DIR)
        new_trainer.handler.fire_event(Events.RESTORE_CHECKPOINT)
        assert new_trainer.epoch_number == 2
        assert new_lr_scheduler.lr_scheduler.last_epoch == 1
        new_trainer.train()
    def test_trainer_saves_and_loads_best_validation_metrics_correctly_1(self):
        # Use -loss and run 1 epoch of original training and one of restored training.
        # Run 1 epoch of original training.
        trainer = CallbackTrainer(
                self.model,
                training_data=self.instances,
                iterator=self.iterator,
                optimizer=self.optimizer,
                callbacks=self.default_callbacks(),
                num_epochs=1, serialization_dir=self.TEST_DIR)
        trainer.train()
        trainer.handler.fire_event(Events.TRAINING_START)
        best_epoch_1 = trainer.metric_tracker.best_epoch
        best_validation_metrics_epoch_1 = trainer.metric_tracker.best_epoch_metrics
        # best_validation_metrics_epoch_1: {'accuracy': 0.75, 'accuracy3': 1.0, 'loss': 0.6243013441562653}
        assert isinstance(best_validation_metrics_epoch_1, dict)
        assert "loss" in best_validation_metrics_epoch_1

        # Run 1 epoch of restored training.
        restore_trainer = CallbackTrainer(
                self.model,
                training_data=self.instances,
                iterator=self.iterator,
                optimizer=self.optimizer,
                callbacks=self.default_callbacks(),
                num_epochs=2, serialization_dir=self.TEST_DIR)
        restore_trainer.train()
        restore_trainer.handler.fire_event(Events.TRAINING_START)
        best_epoch_2 = restore_trainer.metric_tracker.best_epoch
        best_validation_metrics_epoch_2 = restore_trainer.metric_tracker.best_epoch_metrics

        # Because we used -loss, the 2nd epoch should be better than the 1st, so the
        # best validation metrics should differ.
        assert best_epoch_1 == 0 and best_epoch_2 == 1
        assert best_validation_metrics_epoch_2 != best_validation_metrics_epoch_1
    def test_trainer_can_resume_training(self):
        trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(),
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
        )
        trainer.train()

        new_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            callbacks=self.default_callbacks(),
            num_epochs=3,
            serialization_dir=self.TEST_DIR,
        )

        new_trainer.handler.fire_event(Events.TRAINING_START)

        assert new_trainer.epoch_number == 1

        tracker = new_trainer.metric_tracker

        assert tracker is not None
        assert tracker.is_best_so_far()
        assert tracker._best_so_far is not None

        new_trainer.train()
    def test_validation_metrics_consistent_with_and_without_tracking(self):
        default_callbacks = self.default_callbacks(serialization_dir=None)
        default_callbacks_without_tracking = [
            callback for callback in default_callbacks
            if not isinstance(callback, TrackMetrics)
        ]
        trainer1 = CallbackTrainer(
            copy.deepcopy(self.model),
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=copy.deepcopy(self.optimizer),
            callbacks=default_callbacks_without_tracking,
            num_epochs=1,
            serialization_dir=None,
        )

        trainer1.train()

        trainer2 = CallbackTrainer(
            copy.deepcopy(self.model),
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=copy.deepcopy(self.optimizer),
            callbacks=default_callbacks,
            num_epochs=1,
            serialization_dir=None,
        )

        trainer2.train()
        metrics1 = trainer1.val_metrics
        metrics2 = trainer2.val_metrics
        assert metrics1.keys() == metrics2.keys()
        for key in ["accuracy", "accuracy3", "loss"]:
            np.testing.assert_almost_equal(metrics1[key], metrics2[key])
    def test_training_metrics_consistent_with_and_without_validation(self):
        default_callbacks = self.default_callbacks(serialization_dir=None)
        default_callbacks_without_validation = [
            callback for callback in default_callbacks
            if not isinstance(callback, Validate)
        ]
        trainer1 = CallbackTrainer(
            copy.deepcopy(self.model),
            copy.deepcopy(self.optimizer),
            callbacks=default_callbacks_without_validation,
            num_epochs=1,
            serialization_dir=None)

        trainer1.train()

        trainer2 = CallbackTrainer(copy.deepcopy(self.model),
                                   copy.deepcopy(self.optimizer),
                                   callbacks=default_callbacks,
                                   num_epochs=1,
                                   serialization_dir=None)

        trainer2.train()
        metrics1 = trainer1.train_metrics
        metrics2 = trainer2.train_metrics
        assert metrics1.keys() == metrics2.keys()
        for key in ['accuracy', 'accuracy3', 'loss']:
            np.testing.assert_almost_equal(metrics1[key], metrics2[key])
    def test_production_rule_field_with_multiple_gpus(self):
        wikitables_dir = 'allennlp/tests/fixtures/data/wikitables/'
        offline_lf_directory = wikitables_dir + 'action_space_walker_output/'
        wikitables_reader = WikiTablesDatasetReader(
            tables_directory=wikitables_dir,
            offline_logical_forms_directory=offline_lf_directory)
        instances = wikitables_reader.read(wikitables_dir +
                                           'sample_data.examples')
        archive_path = self.FIXTURES_ROOT / 'semantic_parsing' / 'wikitables' / 'serialization' / 'model.tar.gz'
        model = load_archive(archive_path).model
        model.cuda()

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(model.vocab)

        trainer = CallbackTrainer(model,
                                  self.optimizer,
                                  num_epochs=2,
                                  cuda_device=[0, 1],
                                  callbacks=[
                                      GenerateTrainingBatches(
                                          instances, multigpu_iterator),
                                      TrainSupervised()
                                  ])
        trainer.train()
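    # Note the two wiring styles on this page: most tests pass training_data and
    # iterator to the CallbackTrainer constructor, while this example supplies
    # them via GenerateTrainingBatches and TrainSupervised callbacks instead.
    # Both presumably drive the same training loop; only the wiring differs.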
    def test_trainer_can_run_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
        callbacks = self.default_callbacks() + [ComputeMovingAverage(moving_average)]
        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  num_epochs=2,
                                  callbacks=callbacks)
        trainer.train()
    def test_trainer_raises_on_model_with_no_loss_key(self):
        class FakeModel(Model):
            def forward(self, **kwargs):  # pylint: disable=arguments-differ,unused-argument
                return {}

        with pytest.raises(RuntimeError):
            trainer = CallbackTrainer(FakeModel(None), self.optimizer,
                                      callbacks=self.default_callbacks(),
                                      num_epochs=2, serialization_dir=self.TEST_DIR)
            trainer.train()
    def test_trainer_can_run_ema_from_params(self):
        uma_params = Params({"moving_average": {"decay": 0.9999}})
        callbacks = self.default_callbacks() + [UpdateMovingAverage.from_params(uma_params, self.model)]
        trainer = CallbackTrainer(model=self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=2,
                                  callbacks=callbacks)
        trainer.train()
    def test_trainer_can_run_cuda(self):
        self.model.cuda()
        trainer = CallbackTrainer(self.model,
                                  training_data=self.instances,
                                  iterator=self.iterator,
                                  optimizer=self.optimizer,
                                  num_epochs=2,
                                  callbacks=self.default_callbacks(),
                                  cuda_device=0)
        trainer.train()
    def test_trainer_can_run_with_lr_scheduler(self):
        lr_params = Params({"type": "reduce_on_plateau"})
        lr_scheduler = LearningRateScheduler.from_params(self.optimizer, lr_params)
        callbacks = self.default_callbacks() + [UpdateLearningRate(lr_scheduler)]

        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  callbacks=callbacks,
                                  num_epochs=2)
        trainer.train()
    def test_trainer_saves_models_at_specified_interval(self):
        iterator = BasicIterator(batch_size=4)
        iterator.index_with(self.vocab)

        trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=iterator,
            optimizer=self.optimizer,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            callbacks=self.default_callbacks(model_save_interval=0.0001),
        )

        trainer.train()

        # Now check the serialized files for models saved during the epoch.
        prefix = "model_state_epoch_*"
        file_names = sorted(glob.glob(os.path.join(self.TEST_DIR, prefix)))
        epochs = [
            re.search(r"_([0-9\.\-]+)\.th", fname).group(1)
            for fname in file_names
        ]
        # We should have checkpoints at the end of each epoch and during each, e.g.
        # [0.timestamp, 0, 1.timestamp, 1]
        assert len(epochs) == 4
        assert epochs[3] == "1"
        assert "." in epochs[0]

        # Now make certain we can restore from a timestamped checkpoint.
        # To do so, remove the end-of-epoch checkpoints for epochs 0 and 1,
        # so that we are forced to restore from the timestamped checkpoints.
        for k in range(2):
            os.remove(
                os.path.join(self.TEST_DIR,
                             "model_state_epoch_{}.th".format(k)))
            os.remove(
                os.path.join(self.TEST_DIR,
                             "training_state_epoch_{}.th".format(k)))
        os.remove(os.path.join(self.TEST_DIR, "best.th"))

        restore_trainer = CallbackTrainer(
            self.model,
            training_data=self.instances,
            iterator=iterator,
            optimizer=self.optimizer,
            num_epochs=2,
            serialization_dir=self.TEST_DIR,
            callbacks=self.default_callbacks(model_save_interval=0.0001),
        )
        restore_trainer.handler.fire_event(Events.TRAINING_START)
        assert restore_trainer.epoch_number == 2
        # One batch per epoch.
        assert restore_trainer.batch_num_total == 2
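# The mid-epoch checkpoints saved under model_save_interval carry a timestamp
# suffix, which is what `assert "." in epochs[0]` above detects. A quick check
# that the filename regex from that test accepts both forms (the timestamp
# below is hypothetical):
import re

assert re.search(r"_([0-9\.\-]+)\.th", "model_state_epoch_1.th").group(1) == "1"
stamped = re.search(r"_([0-9\.\-]+)\.th",
                    "model_state_epoch_0.2020-01-01-00-00-00.th").group(1)
assert stamped.startswith("0") and "." in stamped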
    def test_trainer_can_run_exponential_moving_average(self):
        moving_average = ExponentialMovingAverage(self.model.named_parameters(), decay=0.9999)
        callbacks = self.default_callbacks() + [UpdateMovingAverage(moving_average)]
        trainer = CallbackTrainer(
            model=self.model,
            training_data=self.instances,
            iterator=self.iterator,
            optimizer=self.optimizer,
            num_epochs=2,
            callbacks=callbacks,
        )
        trainer.train()
    def test_trainer_can_log_learning_rates_tensorboard(self):
        callbacks = [cb for cb in self.default_callbacks() if not isinstance(cb, LogToTensorboard)]
        # The lambda: None is unfortunate, but it will get replaced by the callback.
        tensorboard = TensorboardWriter(lambda: None, should_log_learning_rate=True, summary_interval=2)
        callbacks.append(LogToTensorboard(tensorboard))

        trainer = CallbackTrainer(self.model, self.optimizer,
                                  num_epochs=2,
                                  serialization_dir=self.TEST_DIR,
                                  callbacks=callbacks)

        trainer.train()
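    # The `lambda: None` passed to TensorboardWriter is a placeholder for a
    # get_batch_num_total callable; per the comment above, LogToTensorboard
    # replaces it once training starts, presumably with something equivalent to
    # `lambda: trainer.batch_num_total`.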
    def test_trainer_respects_num_serialized_models_to_keep(self):
        trainer = CallbackTrainer(self.model, self.optimizer,
                                  num_epochs=5,
                                  serialization_dir=self.TEST_DIR,
                                  callbacks=self.default_callbacks(max_checkpoints=3))
        trainer.train()

        # Now check the serialized files
        for prefix in ['model_state_epoch_*', 'training_state_epoch_*']:
            file_names = glob.glob(os.path.join(self.TEST_DIR, prefix))
            epochs = [int(re.search(r"_([0-9])\.th", fname).group(1))
                      for fname in file_names]
            assert sorted(epochs) == [2, 3, 4]
    def test_trainer_posts_to_url(self):
        url = 'http://slack.com?webhook=ewifjweoiwjef'
        responses.add(responses.POST, url)
        post_to_url = PostToUrl(url, message="only a test")
        callbacks = self.default_callbacks() + [post_to_url]
        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  num_epochs=2,
                                  callbacks=callbacks)
        trainer.train()

        assert len(responses.calls) == 1
        assert responses.calls[0].response.request.body == b'{"text": "only a test"}'
    def test_trainer_saves_metrics_every_epoch(self):
        trainer = CallbackTrainer(model=self.model,
                                  optimizer=self.optimizer,
                                  num_epochs=5,
                                  serialization_dir=self.TEST_DIR,
                                  callbacks=self.default_callbacks(max_checkpoints=3))
        trainer.train()

        for epoch in range(5):
            epoch_file = self.TEST_DIR / f'metrics_epoch_{epoch}.json'
            assert epoch_file.exists()
            with open(epoch_file) as f:
                metrics = json.load(f)
            assert "validation_loss" in metrics
            assert "best_validation_loss" in metrics
            assert metrics.get("epoch") == epoch
    def test_trainer_can_log_histograms(self):
        # enable activation logging
        for module in self.model.modules():
            module.should_log_activations = True

        callbacks = [cb for cb in self.default_callbacks() if not isinstance(cb, LogToTensorboard)]
        # The lambda: None is unfortunate, but it will get replaced by the callback.
        tensorboard = TensorboardWriter(lambda: None, histogram_interval=2)
        callbacks.append(LogToTensorboard(tensorboard))

        trainer = CallbackTrainer(self.model, self.optimizer,
                                  num_epochs=3,
                                  serialization_dir=self.TEST_DIR,
                                  callbacks=callbacks)
        trainer.train()
    def test_handle_errors(self):
        class ErrorTest(Callback):
            """
            A callback with three triggers
            * at BATCH_START, it raises a RuntimeError
            * at TRAINING_END, it sets a finished flag to True
            * at ERROR, it captures `trainer.exception`
            """

            def __init__(self) -> None:
                self.exc: Optional[Exception] = None
                self.finished_training = None

            @handle_event(Events.BATCH_START)
            def raise_exception(self, trainer):
                raise RuntimeError("problem starting batch")

            @handle_event(Events.TRAINING_END)
            def finish_training(self, trainer):
                self.finished_training = True

            @handle_event(Events.ERROR)
            def capture_error(self, trainer):
                self.exc = trainer.exception

        error_test = ErrorTest()
        callbacks = self.default_callbacks() + [error_test]

        original_trainer = CallbackTrainer(
            self.model,
            self.instances,
            self.iterator,
            self.optimizer,
            callbacks=callbacks,
            num_epochs=1,
            serialization_dir=self.TEST_DIR,
        )

        with pytest.raises(RuntimeError):
            original_trainer.train()

        # The callback should have captured the exception.
        assert error_test.exc is not None
        assert error_test.exc.args == ("problem starting batch",)

        # The "finished" flag should never have been set to True.
        assert not error_test.finished_training
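# A rough sketch of the mechanism behind @handle_event, assumed and simplified
# (the real implementation also handles ordering/priorities): the decorator tags
# a method with an event name, and the handler fires every tagged method when
# that event occurs.
def handle_event_sketch(event):
    def decorator(method):
        method._event = event  # hypothetical tag attribute the handler matches on
        return method
    return decorator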
    def test_trainer_raises_on_model_with_no_loss_key(self):
        class FakeModel(Model):
            def forward(self, **kwargs):
                return {}

        with pytest.raises(RuntimeError):
            trainer = CallbackTrainer(
                FakeModel(None),
                training_data=self.instances,
                iterator=self.iterator,
                optimizer=self.optimizer,
                callbacks=self.default_callbacks(),
                num_epochs=2,
                serialization_dir=self.TEST_DIR,
            )
            trainer.train()
    def test_trainer_can_run_multiple_gpu(self):
        self.model.cuda()
        class MetaDataCheckWrapper(Model):
            """
            Checks that the metadata field has been correctly split across the batch dimension
            when running on multiple gpus.
            """
            def __init__(self, model):
                super().__init__(model.vocab)
                self.model = model

            def forward(self, **kwargs) -> Dict[str, torch.Tensor]:  # type: ignore # pylint: disable=arguments-differ
                assert 'metadata' in kwargs and 'tags' in kwargs, \
                    f'metadata and tags must be provided. Got {kwargs.keys()} instead.'
                batch_size = kwargs['tokens']['tokens'].size()[0]
                assert len(kwargs['metadata']) == batch_size, \
                    f'metadata must be split appropriately. Expected {batch_size} elements, ' \
                    f"got {len(kwargs['metadata'])} elements."
                return self.model.forward(**kwargs)

        multigpu_iterator = BasicIterator(batch_size=4)
        multigpu_iterator.index_with(self.vocab)
        trainer = CallbackTrainer(MetaDataCheckWrapper(self.model), self.optimizer,
                                  num_epochs=2,
                                  callbacks=self.default_callbacks(iterator=multigpu_iterator),
                                  cuda_device=[0, 1])
        metrics = trainer.train()
        assert 'peak_cpu_memory_MB' in metrics
        assert isinstance(metrics['peak_cpu_memory_MB'], float)
        assert metrics['peak_cpu_memory_MB'] > 0
        assert 'peak_gpu_0_memory_MB' in metrics
        assert isinstance(metrics['peak_gpu_0_memory_MB'], int)
        assert 'peak_gpu_1_memory_MB' in metrics
        assert isinstance(metrics['peak_gpu_1_memory_MB'], int)