Code Example #1
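A PyTorch Lightning test showing that fast_dev_run caps execution to a fixed number of batches, swaps in a DummyLogger, and leaves ModelCheckpoint and EarlyStopping uninvoked.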
# Imports assumed for a PyTorch Lightning 1.x test environment.
import os
from unittest.mock import Mock

import pytest
import torch

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers.base import DummyLogger
from tests.helpers import BoringModel  # import path assumed; BoringModel ships with the Lightning test suite


@pytest.mark.parametrize("fast_dev_run", [1, 4])  # parametrization values assumed; any small batch count works
def test_callbacks_and_logger_not_called_with_fastdevrun(tmpdir, fast_dev_run):
    """Test that ModelCheckpoint, EarlyStopping and Logger are turned off with fast_dev_run."""
    class FastDevRunModel(BoringModel):
        def __init__(self):
            super().__init__()
            self.training_step_call_count = 0
            self.training_epoch_end_call_count = 0
            self.validation_step_call_count = 0
            self.validation_epoch_end_call_count = 0
            self.test_step_call_count = 0

        def training_step(self, batch, batch_idx):
            self.log("some_metric", torch.tensor(7.0))
            self.logger.experiment.dummy_log("some_distribution",
                                             torch.randn(7) + batch_idx)
            self.training_step_call_count += 1
            return super().training_step(batch, batch_idx)

        def training_epoch_end(self, outputs):
            self.training_epoch_end_call_count += 1
            super().training_epoch_end(outputs)

        def validation_step(self, batch, batch_idx):
            self.validation_step_call_count += 1
            return super().validation_step(batch, batch_idx)

        def validation_epoch_end(self, outputs):
            self.validation_epoch_end_call_count += 1
            super().validation_epoch_end(outputs)

        def test_step(self, batch, batch_idx):
            self.test_step_call_count += 1
            return super().test_step(batch, batch_idx)

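    # Mock the callbacks' entry points so the assertions can verify that
    # neither checkpointing nor early stopping ever runs under fast_dev_run.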
    checkpoint_callback = ModelCheckpoint()
    checkpoint_callback.save_checkpoint = Mock()
    early_stopping_callback = EarlyStopping(monitor="foo")
    early_stopping_callback._evaluate_stopping_criteria = Mock()
    trainer_config = dict(
        default_root_dir=tmpdir,
        fast_dev_run=fast_dev_run,
        val_check_interval=2,
        logger=True,
        log_every_n_steps=1,
        callbacks=[checkpoint_callback, early_stopping_callback],
    )
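    # val_check_interval is set explicitly so the assertions below can confirm
    # that fast_dev_run overrides it (trainer.val_check_interval is asserted to be 1.0).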

    def _make_fast_dev_run_assertions(trainer, model):
        # check the call count for train/val/test step/epoch
        assert model.training_step_call_count == fast_dev_run
        assert model.training_epoch_end_call_count == 1
        assert model.validation_step_call_count == (0 if model.validation_step is None else fast_dev_run)
        assert model.validation_epoch_end_call_count == (0 if model.validation_step is None else 1)
        assert model.test_step_call_count == fast_dev_run

        # check trainer arguments
        assert trainer.max_steps == fast_dev_run
        assert trainer.num_sanity_val_steps == 0
        assert trainer.max_epochs == 1
        assert trainer.val_check_interval == 1.0
        assert trainer.check_val_every_n_epoch == 1
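        # fast_dev_run forces a single epoch, skips sanity validation, and
        # resets the validation cadence regardless of trainer_config.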

        # the configured logger is replaced by a DummyLogger under fast_dev_run
        assert isinstance(trainer.logger, DummyLogger)

        # checkpoint callback should not have been called with fast_dev_run
        assert trainer.checkpoint_callback == checkpoint_callback
        checkpoint_callback.save_checkpoint.assert_not_called()
        assert not os.path.exists(checkpoint_callback.dirpath)

        # early stopping should not have been called with fast_dev_run
        assert trainer.early_stopping_callback == early_stopping_callback
        early_stopping_callback._evaluate_stopping_criteria.assert_not_called()

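    # First run: the model defines both training_step and validation_step.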
    train_val_step_model = FastDevRunModel()
    trainer = Trainer(**trainer_config)
    trainer.fit(train_val_step_model)
    trainer.test(train_val_step_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    _make_fast_dev_run_assertions(trainer, train_val_step_model)

    # -----------------------
    # repeat the run with the validation step removed
    # -----------------------
    train_step_only_model = FastDevRunModel()
    train_step_only_model.validation_step = None

    trainer = Trainer(**trainer_config)
    trainer.fit(train_step_only_model)
    trainer.test(train_step_only_model)

    assert trainer.state.finished, f"Training failed with {trainer.state}"
    _make_fast_dev_run_assertions(trainer, train_step_only_model)