Example #1
def test_strict_model_load_more_params(monkeypatch, tmpdir, tmpdir_server,
                                       url_ckpt):
    """Tests use case where trainer saves the model, and user loads it from tags independently."""
    # set $TORCH_HOME, which determines torch hub's cache path, to tmpdir
    monkeypatch.setenv("TORCH_HOME", tmpdir)

    model = BoringModel()
    # Extra layer
    model.c_d3 = torch.nn.Linear(32, 32)

    # logger file to get meta
    logger = tutils.get_default_logger(tmpdir)

    # fit model
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=logger,
        callbacks=[ModelCheckpoint(dirpath=tmpdir)],
    )
    trainer.fit(model)

    # training complete
    assert trainer.state.finished, f"Training failed with {trainer.state}"

    # save model
    new_weights_path = os.path.join(tmpdir, "save_test.ckpt")
    trainer.save_checkpoint(new_weights_path)

    # load new model
    hparams_path = os.path.join(tutils.get_data_path(logger, path_dir=tmpdir),
                                "hparams.yaml")
    hparams_url = f"http://{tmpdir_server[0]}:{tmpdir_server[1]}/{os.path.basename(new_weights_path)}"
    ckpt_path = hparams_url if url_ckpt else new_weights_path

    BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path,
                                     hparams_file=hparams_path,
                                     strict=False)

    with pytest.raises(
            RuntimeError,
            match=r'Unexpected key\(s\) in state_dict: "c_d3.weight", "c_d3.bias"'):
        BoringModel.load_from_checkpoint(checkpoint_path=ckpt_path,
                                         hparams_file=hparams_path,
                                         strict=True)
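The `url_ckpt` flag and the `tmpdir_server` fixture are not shown above. A minimal sketch of what they might look like, assuming `url_ckpt` is a boolean parametrization and `tmpdir_server` serves the temporary directory over HTTP and yields a `(host, port)` pair:

import threading
from functools import partial
from http.server import HTTPServer, SimpleHTTPRequestHandler

import pytest


@pytest.fixture(params=[True, False])
def url_ckpt(request):
    # Run the test once loading the checkpoint by URL, once from a local path.
    return request.param


@pytest.fixture
def tmpdir_server(tmpdir):
    # Serve tmpdir over HTTP so saved checkpoints can be fetched by URL.
    handler = partial(SimpleHTTPRequestHandler, directory=str(tmpdir))
    server = HTTPServer(("localhost", 0), handler)
    thread = threading.Thread(target=server.serve_forever, daemon=True)
    thread.start()
    yield server.server_address  # (host, port), as used in the test above
    server.shutdown()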
Example #2
def _test_loggers_save_dir_and_weights_save_path(tmpdir, logger_class):
    class TestLogger(logger_class):
        # for this test it does not matter what these attributes are
        # so we standardize them to make testing easier
        @property
        def version(self):
            return "version"

        @property
        def name(self):
            return "name"

    model = BoringModel()
    trainer_args = dict(default_root_dir=tmpdir, max_steps=3)

    # no weights_save_path given
    save_dir = tmpdir / "logs"
    weights_save_path = None
    logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
    trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
    trainer.fit(model)
    assert trainer._weights_save_path_internal == trainer.default_root_dir
    assert trainer.checkpoint_callback.dirpath == os.path.join(str(logger.save_dir), "name", "version", "checkpoints")
    assert trainer.default_root_dir == tmpdir

    # with weights_save_path given, the logger path and checkpoint path should be different
    save_dir = tmpdir / "logs"
    weights_save_path = tmpdir / "weights"
    logger = TestLogger(**_get_logger_args(TestLogger, save_dir))
    with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
        trainer = Trainer(**trainer_args, logger=logger, weights_save_path=weights_save_path)
    trainer.fit(model)
    assert trainer._weights_save_path_internal == weights_save_path
    assert trainer.logger.save_dir == save_dir
    assert trainer.checkpoint_callback.dirpath == weights_save_path / "name" / "version" / "checkpoints"
    assert trainer.default_root_dir == tmpdir

    # no logger given
    weights_save_path = tmpdir / "weights"
    with pytest.deprecated_call(match=r"Setting `Trainer\(weights_save_path=\)` has been deprecated in v1.6"):
        trainer = Trainer(**trainer_args, logger=False, weights_save_path=weights_save_path)
    trainer.fit(model)
    assert trainer._weights_save_path_internal == weights_save_path
    assert trainer.checkpoint_callback.dirpath == weights_save_path / "checkpoints"
    assert trainer.default_root_dir == tmpdir
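The `_get_logger_args` helper is referenced but not defined here. A plausible sketch, assuming it maps `save_dir` onto whichever constructor argument the logger class accepts and keeps third-party loggers offline:

import inspect


def _get_logger_args(logger_class, save_dir):
    # Pass the save location under the keyword this logger understands.
    logger_args = {}
    params = inspect.signature(logger_class.__init__).parameters
    if "save_dir" in params:
        logger_args["save_dir"] = str(save_dir)
    if "offline" in params:
        logger_args["offline"] = True
    return logger_args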
Example #3
def test_correct_step_and_epoch(tmpdir):
    model = BoringModel()
    first_max_epochs = 2
    train_batches = 2
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=first_max_epochs,
                      limit_train_batches=train_batches,
                      limit_val_batches=0)
    assert trainer.current_epoch == 0
    assert trainer.global_step == 0

    trainer.fit(model)
    # TODO(@carmocca): should not need `-1`
    assert trainer.current_epoch == first_max_epochs - 1
    assert trainer.global_step == first_max_epochs * train_batches

    # save checkpoint after loop ends, training end called, epoch count increased
    ckpt_path = str(tmpdir / "model.ckpt")
    trainer.save_checkpoint(ckpt_path)

    ckpt = torch.load(ckpt_path)
    assert ckpt["epoch"] == first_max_epochs
    # TODO(@carmocca): should not need `+1`
    assert ckpt["global_step"] == first_max_epochs * train_batches + 1

    max_epochs = first_max_epochs + 2
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=max_epochs,
                      limit_train_batches=train_batches,
                      limit_val_batches=0)
    # the ckpt state is not loaded at this point
    assert trainer.current_epoch == 0
    assert trainer.global_step == 0

    class TestModel(BoringModel):
        def on_pretrain_routine_end(self) -> None:
            assert self.trainer.current_epoch == first_max_epochs
            # TODO(@carmocca): should not need `+1`
            assert self.trainer.global_step == first_max_epochs * train_batches + 1

    trainer.fit(TestModel(), ckpt_path=ckpt_path)
    # TODO(@carmocca): should not need `-1`
    assert trainer.current_epoch == max_epochs - 1
    # TODO(@carmocca): should not need `+1`
    assert trainer.global_step == max_epochs * train_batches + 1
Example #4
def test_v1_5_0_old_on_test_epoch_end(tmpdir):
    callback_warning_cache.clear()

    class OldSignature(Callback):
        def on_test_epoch_end(self, trainer, pl_module):  # noqa
            ...

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=1,
                      callbacks=OldSignature())

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.test(model)

    class OldSignatureModel(BoringModel):
        def on_test_epoch_end(self):  # noqa
            ...

    model = OldSignatureModel()

    with pytest.deprecated_call(match="old signature will be removed in v1.5"):
        trainer.test(model)

    callback_warning_cache.clear()

    class NewSignature(Callback):
        def on_test_epoch_end(self, trainer, pl_module, outputs):
            ...

    trainer.callbacks = [NewSignature()]
    with no_deprecated_call(
            match="`Callback.on_test_epoch_end` signature has changed in v1.3."
    ):
        trainer.test(model)

    class NewSignatureModel(BoringModel):
        def on_test_epoch_end(self, outputs):
            ...

    model = NewSignatureModel()
    with no_deprecated_call(
            match="`ModelHooks.on_test_epoch_end` signature has changed in v1.3."):
        trainer.test(model)
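The `no_deprecated_call` context manager is not defined in this listing. A minimal sketch of an inverse of `pytest.deprecated_call`, assuming it should fail the test when a matching DeprecationWarning is emitted:

import re
import warnings
from contextlib import contextmanager

import pytest


@contextmanager
def no_deprecated_call(match=None):
    # Record all warnings raised in the block and fail if any
    # DeprecationWarning matches the given pattern.
    with warnings.catch_warnings(record=True) as record:
        warnings.simplefilter("always")
        yield
    for w in record:
        if issubclass(w.category, DeprecationWarning) and (
                match is None or re.search(match, str(w.message))):
            pytest.fail(f"DeprecationWarning raised: {w.message}")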
Example #5
def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps,
                              max_steps):
    """Tests that max_steps can be used without max_epochs."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        min_epochs=min_epochs,
        max_epochs=max_epochs,
        min_steps=min_steps,
        max_steps=max_steps,
        enable_model_summary=False,
    )
    trainer.fit(model)

    # check training stopped at max_epochs or max_steps
    if trainer.max_steps and not trainer.max_epochs:
        assert trainer.global_step == trainer.max_steps
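The `min_epochs`/`max_epochs`/`min_steps`/`max_steps` arguments come from a parametrization that is not shown. Hypothetical values consistent with the test body:

import pytest


@pytest.mark.parametrize(
    "min_epochs, max_epochs, min_steps, max_steps",
    [
        (None, 3, None, None),   # stop on max_epochs
        (None, None, None, 20),  # stop on max_steps
        (None, 3, None, 20),
        (3, None, 10, None),
    ],
)
def test_min_max_steps_epochs(tmpdir, min_epochs, max_epochs, min_steps, max_steps):
    ...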
Example #6
def test_ddp_spawn_fp16_compress_comm_hook(tmpdir):
    """Test for DDP Spawn FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPSpawnPlugin(
        ddp_comm_hook=default.fp16_compress_hook,
        sync_batchnorm=True,
    )
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        plugins=[training_type_plugin],
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
Example #7
def test_profiler_teardown(tmpdir, cls):
    """
    This test checks if profiler teardown method is called when trainer is exiting.
    """
    class TestCallback(Callback):
        def on_fit_end(self, trainer, pl_module) -> None:
            assert trainer.profiler.output_file is not None

    profiler = cls(output_filename=os.path.join(tmpdir, "profiler.txt"))

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir,
                      fast_dev_run=True,
                      profiler=profiler,
                      callbacks=[TestCallback()])
    trainer.fit(model)

    assert profiler.output_file is None
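The `cls` argument is presumably parametrized over the profiler classes that accept the (older) `output_filename` argument used above. A plausible decorator:

import pytest
from pytorch_lightning.profiler import AdvancedProfiler, SimpleProfiler


@pytest.mark.parametrize("cls", [SimpleProfiler, AdvancedProfiler])
def test_profiler_teardown(tmpdir, cls):
    ...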
Example #8
def test_all_callback_states_saved_before_checkpoint_callback(tmpdir):
    """Test that all callback states get saved even if the ModelCheckpoint is not given as last."""

    callback0 = StatefulCallback0()
    callback1 = StatefulCallback1()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, filename="all_states")
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir, max_steps=1, limit_val_batches=1, callbacks=[callback0, checkpoint_callback, callback1]
    )
    trainer.fit(model)

    ckpt = torch.load(str(tmpdir / "all_states.ckpt"))
    state0 = ckpt["callbacks"]["StatefulCallback0"]
    state1 = ckpt["callbacks"]["StatefulCallback1"]
    assert "content0" in state0 and state0["content0"] == 0
    assert "content1" in state1 and state1["content1"] == 1
    assert "ModelCheckpoint" in ckpt["callbacks"]
Example #9
def test_ddp_fp16_compress_comm_hook(tmpdir):
    """Test for DDP FP16 compress hook."""
    model = BoringModel()
    training_type_plugin = DDPPlugin(ddp_comm_hook=default.fp16_compress_hook)
    trainer = Trainer(
        max_epochs=1,
        gpus=2,
        strategy=training_type_plugin,
        default_root_dir=tmpdir,
        sync_batchnorm=True,
        fast_dev_run=True,
    )
    trainer.fit(model)
    trainer_comm_hook = trainer.accelerator.training_type_plugin._model.get_ddp_logging_data().comm_hook
    expected_comm_hook = default.fp16_compress_hook.__qualname__
    assert trainer_comm_hook == expected_comm_hook
    assert trainer.state.finished, f"Training failed with {trainer.state}"
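For reference, the plugin's `ddp_comm_hook` argument corresponds to PyTorch's native DDP communication-hook API. A standalone sketch, assuming a process group is already initialized:

import torch
from torch.distributed.algorithms.ddp_comm_hooks import default_hooks as default
from torch.nn.parallel import DistributedDataParallel

# With torch.distributed initialized, the fp16 hook compresses gradient
# buckets to FP16 before all-reduce, halving communication volume.
ddp_model = DistributedDataParallel(torch.nn.Linear(32, 2).cuda())
ddp_model.register_comm_hook(state=None, hook=default.fp16_compress_hook)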
Example #10
def test_model_properties_resume_from_checkpoint(tmpdir):
    """Test that properties like `current_epoch` and `global_step` in model and trainer are always the same."""
    model = BoringModel()
    checkpoint_callback = ModelCheckpoint(dirpath=tmpdir, monitor="val_loss", save_last=True)
    trainer_args = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=2,
        logger=False,
        callbacks=[checkpoint_callback, ModelTrainerPropertyParity()],  # this performs the assertions
    )
    trainer = Trainer(**trainer_args)
    trainer.fit(model)

    trainer_args.update(max_epochs=2)
    trainer = Trainer(**trainer_args, resume_from_checkpoint=str(tmpdir / "last.ckpt"))
    trainer.fit(model)
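`ModelTrainerPropertyParity` is referenced but not defined. A minimal sketch, assuming it simply compares the model's and trainer's progress counters at the relevant hooks:

from pytorch_lightning.callbacks import Callback


class ModelTrainerPropertyParity(Callback):
    # Assert that model and trainer report the same progress counters.
    def _check(self, trainer, pl_module):
        assert trainer.current_epoch == pl_module.current_epoch
        assert trainer.global_step == pl_module.global_step

    def on_train_start(self, trainer, pl_module):
        self._check(trainer, pl_module)

    def on_train_batch_start(self, trainer, pl_module, *args):
        self._check(trainer, pl_module)

    def on_train_batch_end(self, trainer, pl_module, *args):
        self._check(trainer, pl_module)

    def on_train_end(self, trainer, pl_module):
        self._check(trainer, pl_module)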
Example #11
def test_profiler_teardown(tmpdir, cls):
    """
    This test checks if profiler teardown method is called when trainer is exiting.
    """
    class TestCallback(Callback):
        def on_fit_end(self, trainer, *args, **kwargs) -> None:
            # describe sets it to None
            assert trainer.profiler._output_file is None

    profiler = cls(dirpath=tmpdir, filename="profiler")
    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir,
                      fast_dev_run=True,
                      profiler=profiler,
                      callbacks=[TestCallback()])
    trainer.fit(model)

    assert profiler._output_file is None
Example #12
def test_timer_zero_duration_stop(tmpdir, interval):
    """ Test that the timer stops training immediately after the first check occurs. """
    model = BoringModel()
    duration = timedelta(0)
    timer = Timer(duration=duration, interval=interval)
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[timer],
    )
    trainer.fit(model)
    if interval == "step":
        # timer triggers stop on step end
        assert trainer.global_step == 1
        assert trainer.current_epoch == 0
    else:
        # timer triggers stop on epoch end
        assert trainer.global_step == len(trainer.train_dataloader)
        assert trainer.current_epoch == 0
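`interval` is presumably parametrized over the timer's supported check intervals, matching the two branches above:

import pytest


@pytest.mark.parametrize("interval", ["step", "epoch"])
def test_timer_zero_duration_stop(tmpdir, interval):
    ...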
Example #13
def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate,
                                         test_deltas):
    """
    Test that test progress updates with the correct amount.
    """
    model = BoringModel()
    progress_bar = MockedUpdateProgressBars(refresh_rate=refresh_rate)
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_test_batches=test_batches,
        callbacks=[progress_bar],
        logger=False,
        checkpoint_callback=False,
    )
    trainer.test(model)
    progress_bar.test_progress_bar.update.assert_has_calls(
        [call(delta) for delta in test_deltas])
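The parametrization for `test_batches`, `refresh_rate`, and `test_deltas` is not shown. Hypothetical values that match the update logic (the bar advances in steps of `refresh_rate`, plus a final remainder):

import pytest


@pytest.mark.parametrize(
    "test_batches, refresh_rate, test_deltas",
    [(5, 1, [1, 1, 1, 1, 1]), (5, 3, [3, 2])],
)
def test_test_progress_bar_update_amount(tmpdir, test_batches, refresh_rate, test_deltas):
    ...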
Example #14
def test_v1_7_0_old_on_train_batch_end(tmpdir):
    class OldSignature(Callback):
        def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
            ...

    class OldSignatureModel(BoringModel):
        def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
            ...

    model = BoringModel()
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, callbacks=OldSignature(), fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)

    model = OldSignatureModel()
    # no old-signature callback here, so the warning must come from the model hook
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, fast_dev_run=True)
    with pytest.deprecated_call(match="`dataloader_idx` argument will be removed in v1.7."):
        trainer.fit(model)
Example #15
def test_resume_callback_state_saved_by_type(tmpdir):
    """Test that a legacy checkpoint that didn't use a state key before can still be loaded."""
    model = BoringModel()
    callback = OldStatefulCallback(state=111)
    trainer = Trainer(default_root_dir=tmpdir,
                      max_steps=1,
                      callbacks=[callback])
    trainer.fit(model)
    ckpt_path = Path(trainer.checkpoint_callback.best_model_path)
    assert ckpt_path.exists()

    callback = OldStatefulCallback(state=222)
    trainer = Trainer(default_root_dir=tmpdir,
                      max_steps=2,
                      callbacks=[callback],
                      resume_from_checkpoint=ckpt_path)
    trainer.fit(model)
    assert callback.state == 111
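`OldStatefulCallback` is not defined in this listing. A minimal sketch, assuming a legacy-style callback that declares no `state_key`, so its state lands in the checkpoint under the callback type:

from pytorch_lightning.callbacks import Callback


class OldStatefulCallback(Callback):
    def __init__(self, state):
        self.state = state

    def on_save_checkpoint(self, trainer, pl_module, checkpoint):
        return {"state": self.state}

    def on_load_checkpoint(self, trainer, pl_module, callback_state):
        self.state = callback_state["state"]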
Example #16
def test_tpu_clip_grad_by_value(tmpdir):
    """Test if clip_gradients by value works on TPU. (It should not.)"""
    tutils.reset_seed()
    trainer_options = dict(default_root_dir=tmpdir,
                           progress_bar_refresh_rate=0,
                           max_epochs=4,
                           tpu_cores=1,
                           limit_train_batches=10,
                           limit_val_batches=10,
                           gradient_clip_val=0.5,
                           gradient_clip_algorithm='value')

    model = BoringModel()
    with pytest.raises(AssertionError):
        tpipes.run_model_test(trainer_options,
                              model,
                              on_gpu=False,
                              with_hpc=False)
Example #17
def test_model_16bit_tpu_index(tmpdir, tpu_core):
    """Make sure model trains on TPU."""
    tutils.reset_seed()
    trainer_options = dict(
        default_root_dir=tmpdir,
        precision=16,
        progress_bar_refresh_rate=0,
        max_epochs=2,
        tpu_cores=[tpu_core],
        limit_train_batches=4,
        limit_val_batches=2,
    )

    model = BoringModel()
    tpipes.run_model_test(trainer_options, model, on_gpu=False)
    assert torch_xla._XLAC._xla_get_default_device() == f'xla:{tpu_core}'
    assert os.environ.get('XLA_USE_BF16') == '1', "XLA_USE_BF16 was not set in environment variables"
Example #18
def test_device_stats_monitor_no_logger(tmpdir):
    """Test DeviceStatsMonitor with no logger in Trainer."""

    model = BoringModel()
    device_stats = DeviceStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[device_stats],
        max_epochs=1,
        logger=False,
        enable_checkpointing=False,
        enable_progress_bar=False,
    )

    with pytest.raises(MisconfigurationException,
                       match="Trainer that has no logger."):
        trainer.fit(model)
Example #19
def test_gpu_stats_monitor_no_logger(tmpdir):
    """Test GPUStatsMonitor with no logger in Trainer."""
    model = BoringModel()
    gpu_stats = GPUStatsMonitor()

    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[gpu_stats],
        max_epochs=1,
        gpus=1,
        logger=False,
    )

    with pytest.raises(MisconfigurationException,
                       match='Trainer that has no logger.'):
        trainer.fit(model)
Example #20
def test_simple_profiler_with_nonexisting_log_dir(tmpdir):
    """Ensure the profiler dirpath defaults to `trainer.log_dir`and creates it when not present."""
    nonexisting_tmpdir = tmpdir / "nonexisting"

    profiler = SimpleProfiler(filename="profiler")
    assert profiler.dirpath is None

    model = BoringModel()
    trainer = Trainer(
        default_root_dir=nonexisting_tmpdir, max_epochs=1, limit_train_batches=1, limit_val_batches=1, profiler=profiler
    )
    trainer.fit(model)

    expected = nonexisting_tmpdir / "lightning_logs" / "version_0"
    assert expected.exists()
    assert trainer.log_dir == expected
    assert profiler.dirpath == trainer.log_dir
    assert expected.join("fit-profiler.txt").exists()
Example #21
def test_pytorch_profiler_trainer_validate(tmpdir):
    """Ensure that the profiler can be given to the trainer and validate function are properly recorded. """
    pytorch_profiler = PyTorchProfiler(dirpath=tmpdir,
                                       filename="profile",
                                       schedule=None)
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_val_batches=2,
        profiler=pytorch_profiler,
    )
    trainer.validate(model)

    assert sum(e.name == 'validation_step' for e in pytorch_profiler.function_events)

    path = pytorch_profiler.dirpath / f"validate-{pytorch_profiler.filename}.txt"
    assert path.read_text("utf-8")
Example #22
def test_lr_monitor_single_lr(tmpdir):
    """Test that learning rates are extracted and logged for single lr scheduler."""
    tutils.reset_seed()

    model = BoringModel()

    lr_monitor = LearningRateMonitor()
    trainer = Trainer(default_root_dir=tmpdir,
                      max_epochs=2,
                      limit_val_batches=0.1,
                      limit_train_batches=0.5,
                      callbacks=[lr_monitor])
    trainer.fit(model)

    assert lr_monitor.lrs, "No learning rates logged"
    assert all(v is None for v in lr_monitor.last_momentum_values.values()
               ), "Momentum should not be logged by default"
    assert len(lr_monitor.lrs) == len(trainer.lr_scheduler_configs)
    assert list(lr_monitor.lrs) == ["lr-SGD"]
Example #23
def test_hpc_max_ckpt_version(tmpdir):
    """Test that the CheckpointConnector finds the hpc checkpoint file with the highest version."""
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
    )
    trainer.fit(model)
    trainer.save_checkpoint(tmpdir / "hpc_ckpt.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_0.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_3.ckpt")
    trainer.save_checkpoint(tmpdir / "hpc_ckpt_33.ckpt")

    assert trainer.checkpoint_connector.hpc_resume_path == str(tmpdir / "hpc_ckpt_33.ckpt")
    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir) == 33
    assert trainer.checkpoint_connector.max_ckpt_version_in_folder(tmpdir / "not" / "existing") is None
Example #24
def test_amp_without_apex(tmpdir):
    """Check that even with apex amp type without requesting precision=16 the amp backend is void."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        amp_backend='native',
    )
    assert trainer.amp_backend is None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        amp_backend='apex',
    )
    assert trainer.amp_backend is None
    trainer.fit(model)
    assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
    assert trainer.dev_debugger.count_events('AMP') == 0
Example #25
def test_eval_distributed_sampler_warning(tmpdir):
    """Test that a warning is raised when `DistributedSampler` is used with evaluation."""

    model = BoringModel()
    trainer = Trainer(strategy="ddp",
                      devices=2,
                      accelerator="cpu",
                      fast_dev_run=True)
    trainer._data_connector.attach_data(model)

    trainer.state.fn = TrainerFn.VALIDATING
    with pytest.warns(PossibleUserWarning,
                      match="multi-device settings use `DistributedSampler`"):
        trainer.reset_val_dataloader(model)

    trainer.state.fn = TrainerFn.TESTING
    with pytest.warns(PossibleUserWarning,
                      match="multi-device settings use `DistributedSampler`"):
        trainer.reset_test_dataloader(model)
Example #26
def test_multi_gpu_model_ddp_spawn(tmpdir):
    tutils.set_random_main_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        strategy="ddp_spawn",
        enable_progress_bar=False,
    )

    model = BoringModel()

    tpipes.run_model_test(trainer_options, model)

    # test memory helper functions
    memory.get_memory_profile("min_max")
Example #27
def test_accumulated_gradient_batches_with_resume_from_checkpoint(tmpdir):
    """Test that the accumulated-gradient schedule is properly recomputed and reset on the trainer when resuming."""

    ckpt = ModelCheckpoint(dirpath=tmpdir, save_last=True)
    model = BoringModel()
    trainer_kwargs = dict(max_epochs=1,
                          accumulate_grad_batches={0: 2},
                          callbacks=ckpt,
                          limit_train_batches=1,
                          limit_val_batches=0)
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)

    trainer_kwargs['max_epochs'] = 2
    trainer_kwargs['resume_from_checkpoint'] = ckpt.last_model_path
    trainer = Trainer(**trainer_kwargs)
    trainer.fit(model)
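`accumulate_grad_batches={0: 2}` schedules accumulation per epoch: from epoch 0 on, two batches are accumulated per optimizer step. The dict form is shorthand for passing the scheduler callback explicitly:

from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import GradientAccumulationScheduler

# Equivalent to Trainer(accumulate_grad_batches={0: 2}).
accumulator = GradientAccumulationScheduler(scheduling={0: 2})
trainer = Trainer(max_epochs=1, callbacks=[accumulator])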
Example #28
def test_multi_gpu_model_dp(tmpdir):
    tutils.set_random_master_port()

    trainer_options = dict(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=10,
        limit_val_batches=10,
        gpus=[0, 1],
        accelerator='dp',
        progress_bar_refresh_rate=0,
    )

    model = BoringModel()

    tpipes.run_model_test(trainer_options, model)

    # test memory helper functions
    memory.get_memory_profile('min_max')
Example #29
def test_amp_without_apex(bwd_mock, tmpdir):
    """Check that even with apex amp type without requesting precision=16 the amp backend is void."""
    model = BoringModel()

    trainer = Trainer(
        default_root_dir=tmpdir,
        amp_backend='native',
    )
    assert trainer.amp_backend is None

    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        amp_backend='apex',
    )
    assert trainer.amp_backend is None
    trainer.fit(model)
    assert trainer.state.finished, f"Training failed with {trainer.state}"
    assert not bwd_mock.called
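The `bwd_mock` argument implies a `mock.patch` decorator that isn't shown; presumably it patches the apex precision plugin's `backward` so the final assertion can verify it was never invoked. A hedged sketch:

from unittest import mock


@mock.patch("pytorch_lightning.plugins.precision.apex_amp.ApexMixedPrecisionPlugin.backward")
def test_amp_without_apex(bwd_mock, tmpdir):
    ...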
Example #30
def test_prediction_writer_hook_call_intervals(tmpdir):
    """Test that the `write_on_batch_end` and `write_on_epoch_end` hooks get invoked based on the defined
    interval."""
    DummyPredictionWriter.write_on_batch_end = Mock()
    DummyPredictionWriter.write_on_epoch_end = Mock()

    dataloader = DataLoader(RandomDataset(32, 64))

    model = BoringModel()
    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    results = trainer.predict(model, dataloaders=dataloader)
    assert len(results) == 4
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("batch_and_epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 1

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("batch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 4
    assert cb.write_on_epoch_end.call_count == 0

    DummyPredictionWriter.write_on_batch_end.reset_mock()
    DummyPredictionWriter.write_on_epoch_end.reset_mock()

    cb = DummyPredictionWriter("epoch")
    trainer = Trainer(limit_predict_batches=4, callbacks=cb)
    trainer.predict(model, dataloaders=dataloader, return_predictions=False)
    assert cb.write_on_batch_end.call_count == 0
    assert cb.write_on_epoch_end.call_count == 1
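`DummyPredictionWriter` is not defined here. A minimal sketch, assuming it subclasses `BasePredictionWriter` with no-op hooks (the test replaces both hooks with `Mock()` to count calls per `write_interval`):

from pytorch_lightning.callbacks import BasePredictionWriter


class DummyPredictionWriter(BasePredictionWriter):
    # The inherited constructor takes the write interval, e.g. "batch",
    # "epoch", or "batch_and_epoch", as used in the test above.
    def write_on_batch_end(self, *args, **kwargs):
        pass

    def write_on_epoch_end(self, *args, **kwargs):
        pass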