Example #1
0
def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
    """Verify the recorded hook order when a validation-free fit is resumed from a checkpoint.

    A one-step run with ``BoringModel`` produces a checkpoint first. A second
    ``Trainer`` is then built with ``resume_from_checkpoint`` pointing at it and
    fitted with ``HookedModel`` and ``HookedCallback``. The ``called`` list that
    is handed to both is compared against the complete expected sequence of
    hook records.
    """
    shared_kwargs = dict(
        default_root_dir=tmpdir,
        limit_val_batches=0,
        progress_bar_refresh_rate=0,
        weights_summary=None,
    )

    # warm-up run: a single step, needed only for the checkpoint it leaves behind
    warmup_model = BoringModel()
    warmup_trainer = Trainer(max_steps=1, **shared_kwargs)
    warmup_trainer.fit(warmup_model)
    ckpt_path = warmup_trainer.checkpoint_callback.best_model_path

    # resumed run: 1 step was already performed, so allow `train_batches` more
    called = []
    model = HookedModel(called)
    callback = HookedCallback(called)
    train_batches = 2
    total_steps = 1 + train_batches
    trainer = Trainer(
        max_steps=total_steps,
        resume_from_checkpoint=ckpt_path,
        callbacks=[callback],
        **shared_kwargs,
    )

    # only the init hooks are expected to be recorded before `fit` is called
    init_records = [
        {'name': 'Callback.on_init_start', 'args': (trainer, )},
        {'name': 'Callback.on_init_end', 'args': (trainer, )},
    ]
    assert called == init_records

    trainer.fit(model)

    def hook_pair(hook):
        # expected callback-side record followed by the model-side record of the same hook
        return [
            {'name': 'Callback.' + hook, 'args': (trainer, model)},
            {'name': hook},
        ]

    stage_fit = {'stage': 'fit'}
    epoch_outputs = [{'loss': ANY} for _ in range(train_batches)]
    loaded_ckpt = {
        'callbacks': ANY,
        'epoch': 1,
        'global_step': 1,
        'lr_schedulers': ANY,
        'optimizer_states': ANY,
        'pytorch-lightning_version': __version__,
        'state_dict': ANY,
    }
    # TODO: wrong saved epoch
    saved_ckpt = dict(loaded_ckpt, epoch=2, global_step=total_steps)

    expected = [
        *init_records,
        {'name': 'prepare_data'},
        {'name': 'configure_callbacks'},
        {'name': 'Callback.on_before_accelerator_backend_setup', 'args': (trainer, model)},
        {'name': 'Callback.setup', 'args': (trainer, model), 'kwargs': stage_fit},
        {'name': 'setup', 'kwargs': stage_fit},
        {'name': 'on_load_checkpoint', 'args': (loaded_ckpt, )},
        {'name': 'configure_sharded_model'},
        {'name': 'Callback.on_configure_sharded_model', 'args': (trainer, model)},
        {'name': 'configure_optimizers'},
        *hook_pair('on_fit_start'),
        *hook_pair('on_pretrain_routine_start'),
        *hook_pair('on_pretrain_routine_end'),
        {'name': 'train', 'args': (True, )},
        {'name': 'on_train_dataloader'},
        {'name': 'train_dataloader'},
        # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
        {'name': 'on_val_dataloader'},
        {'name': 'val_dataloader'},
        *hook_pair('on_train_start'),
        *hook_pair('on_epoch_start'),
        *hook_pair('on_train_epoch_start'),
        # TODO: wrong current epoch after reload
        *model._train_batch(trainer, model, train_batches, current_epoch=1),
        {'name': 'training_epoch_end', 'args': (epoch_outputs, )},
        {'name': 'Callback.on_train_epoch_end', 'args': (trainer, model, epoch_outputs)},
        {'name': 'on_train_epoch_end', 'args': (epoch_outputs, )},
        *hook_pair('on_epoch_end'),
        {'name': 'Callback.on_save_checkpoint', 'args': (trainer, model, saved_ckpt)},
        {'name': 'on_save_checkpoint', 'args': (saved_ckpt, )},
        *hook_pair('on_train_end'),
        *hook_pair('on_fit_end'),
        {'name': 'Callback.teardown', 'args': (trainer, model), 'kwargs': stage_fit},
        {'name': 'teardown', 'kwargs': stage_fit},
    ]
    assert called == expected
Example #2
0
def test_trainer_model_hook_system_fit_no_val_and_resume(tmpdir):
    """Verify the recorded hook order when a validation-free fit is resumed via ``ckpt_path``.

    A one-step run with ``BoringModel`` produces a checkpoint first. A second
    ``Trainer`` is then fitted with ``HookedModel`` and ``HookedCallback``,
    continuing from that checkpoint. The ``called`` list that is handed to both
    is compared against the complete expected sequence of hook records.
    """
    trainer_kwargs = dict(
        default_root_dir=tmpdir,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # initial training to get a checkpoint
    initial_model = BoringModel()
    initial_trainer = Trainer(max_steps=1, callbacks=[HookedCallback([])], **trainer_kwargs)
    initial_trainer.fit(initial_model)
    ckpt_path = initial_trainer.checkpoint_callback.best_model_path

    # resume from checkpoint with HookedModel
    called = []
    model = HookedModel(called)
    callback = HookedCallback(called)

    # 1 step was already performed; the resumed run may do `train_batches` more
    train_batches = 2
    steps_after_reload = 1 + train_batches
    trainer = Trainer(
        max_steps=steps_after_reload,
        callbacks=[callback],
        track_grad_norm=1,
        **trainer_kwargs,
    )

    # only the init hooks are expected to be recorded before `fit` is called
    init_phase = [
        {"name": "Callback.on_init_start", "args": (trainer,)},
        {"name": "Callback.on_init_end", "args": (trainer,)},
    ]
    assert called == init_phase

    trainer.fit(model, ckpt_path=ckpt_path)

    cb_args = (trainer, model)
    fit_stage = {"stage": "fit"}
    ckpt_template = {
        "callbacks": ANY,
        "lr_schedulers": ANY,
        "optimizer_states": ANY,
        "pytorch-lightning_version": __version__,
        "state_dict": ANY,
        "loops": ANY,
    }
    # TODO: wrong saved epoch, should be 0
    loaded_ckpt = dict(ckpt_template, epoch=1, global_step=1)
    # TODO: wrong saved epoch, should be 0
    saved_ckpt = dict(ckpt_template, epoch=2, global_step=steps_after_reload)

    setup_phase = [
        {"name": "configure_callbacks"},
        {"name": "prepare_data"},
        {"name": "Callback.on_before_accelerator_backend_setup", "args": cb_args},
        {"name": "Callback.setup", "args": cb_args, "kwargs": fit_stage},
        {"name": "setup", "kwargs": fit_stage},
        {"name": "on_load_checkpoint", "args": (loaded_ckpt,)},
        {"name": "Callback.on_load_checkpoint", "args": (trainer, model, {"foo": True})},
        {"name": "configure_sharded_model"},
        {"name": "Callback.on_configure_sharded_model", "args": cb_args},
        {"name": "configure_optimizers"},
    ]
    start_phase = [
        {"name": "Callback.on_fit_start", "args": cb_args},
        {"name": "on_fit_start"},
        {"name": "Callback.on_pretrain_routine_start", "args": cb_args},
        {"name": "on_pretrain_routine_start"},
        {"name": "Callback.on_pretrain_routine_end", "args": cb_args},
        {"name": "on_pretrain_routine_end"},
        {"name": "train", "args": (True,)},
        {"name": "on_train_dataloader"},
        {"name": "train_dataloader"},
        # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
        {"name": "on_val_dataloader"},
        {"name": "val_dataloader"},
        {"name": "Callback.on_train_start", "args": cb_args},
        {"name": "on_train_start"},
    ]
    epoch_phase = [
        {"name": "Callback.on_epoch_start", "args": cb_args},
        {"name": "on_epoch_start"},
        {"name": "Callback.on_train_epoch_start", "args": cb_args},
        {"name": "on_train_epoch_start"},
        # TODO: wrong current epoch after reload
        *model._train_batch(trainer, model, train_batches, current_epoch=1),
        {"name": "training_epoch_end", "args": ([{"loss": ANY} for _ in range(train_batches)],)},
        {"name": "Callback.on_train_epoch_end", "args": cb_args},
        {"name": "Callback.on_save_checkpoint", "args": (trainer, model, saved_ckpt)},
        {"name": "on_save_checkpoint", "args": (saved_ckpt,)},
        {"name": "on_train_epoch_end"},
        {"name": "Callback.on_epoch_end", "args": cb_args},
        {"name": "on_epoch_end"},
    ]
    end_phase = [
        {"name": "Callback.on_train_end", "args": cb_args},
        {"name": "on_train_end"},
        {"name": "Callback.on_fit_end", "args": cb_args},
        {"name": "on_fit_end"},
        {"name": "Callback.teardown", "args": cb_args, "kwargs": fit_stage},
        {"name": "teardown", "kwargs": fit_stage},
    ]

    expected = init_phase + setup_phase + start_phase + epoch_phase + end_phase
    assert called == expected
Example #3
0
def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
    """Verify the recorded hook order when an epoch-bounded, validation-free fit is resumed.

    The first run trains ``BoringModel`` with ``max_epochs=1`` to obtain a
    checkpoint. The second run raises ``max_epochs`` to 2 and continues from
    that checkpoint through ``ckpt_path``. The ``called`` list given to
    ``HookedCallback`` and ``HookedModel`` must then equal the complete expected
    sequence of hook records.
    """
    train_batches = 2
    common_kwargs = dict(
        default_root_dir=tmpdir,
        limit_train_batches=train_batches,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
    )

    # initial training to get a checkpoint
    first_model = BoringModel()
    first_trainer = Trainer(max_epochs=1, callbacks=[HookedCallback([])], **common_kwargs)
    first_trainer.fit(first_model)
    ckpt_path = first_trainer.checkpoint_callback.best_model_path

    called = []
    callback = HookedCallback(called)
    # the first run was capped at 1 epoch; max_epochs=2 leaves one more epoch after resuming
    trainer = Trainer(
        max_epochs=2,
        callbacks=[callback],
        track_grad_norm=1,
        **common_kwargs,
    )

    # only the init hooks are expected to be recorded before `fit` is called
    init_hooks = [
        {"name": "Callback.on_init_start", "args": (trainer, )},
        {"name": "Callback.on_init_end", "args": (trainer, )},
    ]
    assert called == init_hooks

    # resume from checkpoint with HookedModel
    model = HookedModel(called)
    trainer.fit(model, ckpt_path=ckpt_path)

    def callback_hook(hook, *extra_args, **hook_kwargs):
        # expected record for a callback hook whose args start with (trainer, model)
        record = {"name": "Callback." + hook, "args": (trainer, model) + extra_args}
        if hook_kwargs:
            record["kwargs"] = hook_kwargs
        return record

    callback_state = {"foo": True}
    loaded_ckpt = {
        "callbacks": ANY,
        "epoch": 0,
        "global_step": train_batches,
        "lr_schedulers": ANY,
        "optimizer_states": ANY,
        "pytorch-lightning_version": __version__,
        "state_dict": ANY,
        "loops": ANY,
    }
    saved_ckpt = dict(loaded_ckpt, global_step=2 * train_batches, epoch=1)

    expected = [
        *init_hooks,
        {"name": "configure_callbacks"},
        {"name": "prepare_data"},
        callback_hook("on_before_accelerator_backend_setup"),
        callback_hook("setup", stage="fit"),
        {"name": "setup", "kwargs": {"stage": "fit"}},
        {"name": "on_load_checkpoint", "args": (loaded_ckpt, )},
        callback_hook("on_load_checkpoint", callback_state),
        {"name": "Callback.load_state_dict", "args": (callback_state, )},
        {"name": "configure_sharded_model"},
        callback_hook("on_configure_sharded_model"),
        {"name": "configure_optimizers"},
        callback_hook("on_fit_start"),
        {"name": "on_fit_start"},
        callback_hook("on_pretrain_routine_start"),
        {"name": "on_pretrain_routine_start"},
        callback_hook("on_pretrain_routine_end"),
        {"name": "on_pretrain_routine_end"},
        {"name": "train", "args": (True, )},
        {"name": "train_dataloader"},
        callback_hook("on_train_start"),
        {"name": "on_train_start"},
        callback_hook("on_epoch_start"),
        {"name": "on_epoch_start"},
        callback_hook("on_train_epoch_start"),
        {"name": "on_train_epoch_start"},
        *model._train_batch(
            trainer, model, train_batches, current_epoch=1, current_batch=0),
        {"name": "training_epoch_end", "args": ([{"loss": ANY}] * train_batches, )},
        callback_hook("on_train_epoch_end"),
        {"name": "Callback.state_dict"},
        callback_hook("on_save_checkpoint", saved_ckpt),
        {"name": "on_save_checkpoint", "args": (saved_ckpt, )},
        {"name": "on_train_epoch_end"},
        callback_hook("on_epoch_end"),
        {"name": "on_epoch_end"},
        callback_hook("on_train_end"),
        {"name": "on_train_end"},
        callback_hook("on_fit_end"),
        {"name": "on_fit_end"},
        callback_hook("teardown", stage="fit"),
        {"name": "teardown", "kwargs": {"stage": "fit"}},
    ]
    assert called == expected