Example #1
0
def test_trainer_callback_hook_system_validate(tmpdir):
    """Check the exact sequence of callback hooks fired by ``Trainer.validate``.

    A ``MagicMock`` is registered as the only callback, so every hook the
    trainer invokes on it is recorded, in order, in ``method_calls``. The
    recorded sequence for a two-batch validation run is compared against the
    expected ordering.
    """
    val_batches = 2

    model = BoringModel()
    recorder = MagicMock()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[recorder],
        max_epochs=1,
        limit_val_batches=val_batches,
        progress_bar_refresh_rate=0,
    )

    trainer.validate(model)

    # Each validation batch fires a start/end pair; ANY matches arguments whose
    # exact values this test does not pin down.
    per_batch_hooks = []
    for batch_idx in range(val_batches):
        per_batch_hooks.append(call.on_validation_batch_start(trainer, model, ANY, batch_idx, 0))
        per_batch_hooks.append(call.on_validation_batch_end(trainer, model, ANY, ANY, batch_idx, 0))

    expected = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'validate'),
        call.on_configure_sharded_model(trainer, model),
        call.on_validation_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_validation_epoch_start(trainer, model),
        *per_batch_hooks,
        call.on_validation_epoch_end(trainer, model),
        call.on_epoch_end(trainer, model),
        call.on_validation_end(trainer, model),
        call.teardown(trainer, model, 'validate'),
    ]
    assert recorder.method_calls == expected
Example #2
0
def test_trainer_callback_hook_system_fit(_, tmpdir):
    """Check the exact sequence of callback hooks fired by ``Trainer.fit``.

    A ``MagicMock`` is registered as the only callback, so every hook the
    trainer invokes on it is recorded, in order, in ``method_calls``. The test
    first asserts that constructing the Trainer fires only the two init hooks.
    It then asserts the full hook sequence for a one-epoch fit with 3 training
    batches and 1 validation batch.

    The first positional argument is unused. It is presumably injected by a
    decorator not visible here (e.g. a mock patch); confirm against the file.
    """
    train_batches = 3
    val_batches = 1

    model = BoringModel()
    callback_mock = MagicMock()
    trainer = Trainer(
        default_root_dir=tmpdir,
        callbacks=[callback_mock],
        max_epochs=1,
        limit_val_batches=val_batches,
        limit_train_batches=train_batches,
        progress_bar_refresh_rate=0,
    )

    init_calls = [
        call.on_init_start(trainer),
        call.on_init_end(trainer),
    ]

    # Check that only the two init hooks have been called so far.
    assert trainer.callbacks[0] == callback_mock
    assert callback_mock.method_calls == init_calls

    # fit model
    trainer.fit(model)

    def validation_loop_calls(num_batches):
        """Return the hooks expected for one validation loop over ``num_batches`` batches."""
        hooks = [
            call.on_validation_start(trainer, model),
            call.on_epoch_start(trainer, model),
            call.on_validation_epoch_start(trainer, model),
        ]
        for batch_idx in range(num_batches):
            hooks += [
                call.on_validation_batch_start(trainer, model, ANY, batch_idx, 0),
                call.on_validation_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
            ]
        hooks += [
            call.on_validation_epoch_end(trainer, model),
            call.on_epoch_end(trainer, model),
            call.on_validation_end(trainer, model),
        ]
        return hooks

    def train_batch_calls(batch_idx):
        """Return the hooks expected for the training batch at ``batch_idx``."""
        # This helper is only called after ``fit``, so ``trainer.optimizers``
        # is read once fitting has finished.
        return [
            call.on_batch_start(trainer, model),
            call.on_train_batch_start(trainer, model, ANY, batch_idx, 0),
            call.on_before_zero_grad(trainer, model, trainer.optimizers[0]),
            call.on_after_backward(trainer, model),
            call.on_train_batch_end(trainer, model, ANY, ANY, batch_idx, 0),
            call.on_batch_end(trainer, model),
        ]

    train_batch_hooks = []
    for batch_idx in range(train_batches):
        train_batch_hooks += train_batch_calls(batch_idx)

    expected = [
        *init_calls,
        call.on_before_accelerator_backend_setup(trainer, model),
        call.setup(trainer, model, 'fit'),
        call.on_configure_sharded_model(trainer, model),
        call.on_fit_start(trainer, model),
        call.on_pretrain_routine_start(trainer, model),
        call.on_pretrain_routine_end(trainer, model),
        call.on_sanity_check_start(trainer, model),
        # The sanity check here covers a single validation batch.
        *validation_loop_calls(1),
        call.on_sanity_check_end(trainer, model),
        call.on_train_start(trainer, model),
        call.on_epoch_start(trainer, model),
        call.on_train_epoch_start(trainer, model),
        *train_batch_hooks,
        call.on_train_epoch_end(trainer, model, ANY),
        call.on_epoch_end(trainer, model),
        *validation_loop_calls(val_batches),
        # Should take ANY, but the hook signature is inspected for backward
        # compatibility (BC), so only (trainer, model) is expected here.
        call.on_save_checkpoint(trainer, model),
        call.on_train_end(trainer, model),
        call.on_fit_end(trainer, model),
        call.teardown(trainer, model, 'fit'),
    ]
    assert callback_mock.method_calls == expected