def test_trainer_model_hook_system_fit_no_val_and_resume_legacy(tmpdir):
    """Assert the exact hook-call sequence when fitting resumes from a checkpoint with validation disabled.

    Legacy variant using the deprecated ``resume_from_checkpoint`` /
    ``progress_bar_refresh_rate`` / ``weights_summary`` Trainer arguments.

    NOTE(review): renamed from ``test_trainer_model_hook_system_fit_no_val_and_resume``
    because a second function with that exact name is defined later in this file,
    which silently shadowed this definition so pytest never collected it (F811).
    """
    # initial training to get a checkpoint
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=0,
        progress_bar_refresh_rate=0,
        weights_summary=None,
    )
    trainer.fit(model)
    best_model_path = trainer.checkpoint_callback.best_model_path

    # resume from checkpoint with HookedModel; both the model and the callback
    # append their hook names into the shared `called` list
    called = []
    model = HookedModel(called)
    callback = HookedCallback(called)
    train_batches = 2
    trainer = Trainer(
        default_root_dir=tmpdir,
        # already performed 1 step, now resuming to do an additional 2
        max_steps=(1 + train_batches),
        limit_val_batches=0,
        progress_bar_refresh_rate=0,
        weights_summary=None,
        resume_from_checkpoint=best_model_path,
        callbacks=[callback],
    )
    # only the init hooks have fired before `fit` is called
    assert called == [
        dict(name='Callback.on_init_start', args=(trainer, )),
        dict(name='Callback.on_init_end', args=(trainer, )),
    ]
    trainer.fit(model)
    # checkpoint dict expected to be passed to the save hooks
    saved_ckpt = {
        'callbacks': ANY,
        'epoch': 2,  # TODO: wrong saved epoch
        'global_step': (1 + train_batches),
        'lr_schedulers': ANY,
        'optimizer_states': ANY,
        'pytorch-lightning_version': __version__,
        'state_dict': ANY,
    }
    expected = [
        dict(name='Callback.on_init_start', args=(trainer, )),
        dict(name='Callback.on_init_end', args=(trainer, )),
        dict(name='prepare_data'),
        dict(name='configure_callbacks'),
        dict(name='Callback.on_before_accelerator_backend_setup', args=(trainer, model)),
        dict(name='Callback.setup', args=(trainer, model), kwargs=dict(stage='fit')),
        dict(name='setup', kwargs=dict(stage='fit')),
        # the checkpoint written after the first (1-step) run is reloaded here
        dict(name='on_load_checkpoint', args=({
            'callbacks': ANY,
            'epoch': 1,
            'global_step': 1,
            'lr_schedulers': ANY,
            'optimizer_states': ANY,
            'pytorch-lightning_version': __version__,
            'state_dict': ANY,
        }, )),
        dict(name='configure_sharded_model'),
        dict(name='Callback.on_configure_sharded_model', args=(trainer, model)),
        dict(name='configure_optimizers'),
        dict(name='Callback.on_fit_start', args=(trainer, model)),
        dict(name='on_fit_start'),
        dict(name='Callback.on_pretrain_routine_start', args=(trainer, model)),
        dict(name='on_pretrain_routine_start'),
        dict(name='Callback.on_pretrain_routine_end', args=(trainer, model)),
        dict(name='on_pretrain_routine_end'),
        dict(name='train', args=(True, )),
        dict(name='on_train_dataloader'),
        dict(name='train_dataloader'),
        # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
        dict(name='on_val_dataloader'),
        dict(name='val_dataloader'),
        dict(name='Callback.on_train_start', args=(trainer, model)),
        dict(name='on_train_start'),
        dict(name='Callback.on_epoch_start', args=(trainer, model)),
        dict(name='on_epoch_start'),
        dict(name='Callback.on_train_epoch_start', args=(trainer, model)),
        dict(name='on_train_epoch_start'),
        # TODO: wrong current epoch after reload
        *model._train_batch(trainer, model, train_batches, current_epoch=1),
        dict(name='training_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
        dict(name='Callback.on_train_epoch_end', args=(
            trainer,
            model,
            [dict(loss=ANY)] * train_batches,
        )),
        dict(name='on_train_epoch_end', args=([dict(loss=ANY)] * train_batches, )),
        dict(name='Callback.on_epoch_end', args=(trainer, model)),
        dict(name='on_epoch_end'),
        dict(name='Callback.on_save_checkpoint', args=(trainer, model, saved_ckpt)),
        dict(name='on_save_checkpoint', args=(saved_ckpt, )),
        dict(name='Callback.on_train_end', args=(trainer, model)),
        dict(name='on_train_end'),
        dict(name='Callback.on_fit_end', args=(trainer, model)),
        dict(name='on_fit_end'),
        dict(name='Callback.teardown', args=(trainer, model), kwargs=dict(stage='fit')),
        dict(name='teardown', kwargs=dict(stage='fit')),
    ]
    assert called == expected
def test_trainer_model_hook_system_fit_no_val_and_resume_max_steps(tmpdir):
    """Assert the exact hook-call sequence when fitting resumes from a checkpoint and stops on ``max_steps``.

    Renamed from ``test_trainer_model_hook_system_fit_no_val_and_resume``: that
    name was defined twice in this file (the earlier definition was silently
    shadowed, F811), and the ``_max_steps`` suffix mirrors the sibling
    ``..._max_epochs`` test.
    """
    # initial training to get a checkpoint
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=1,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[HookedCallback([])],
    )
    trainer.fit(model)
    best_model_path = trainer.checkpoint_callback.best_model_path

    # resume from checkpoint with HookedModel; the model and callback append
    # their hook names into the shared `called` list
    called = []
    model = HookedModel(called)
    callback = HookedCallback(called)
    # already performed 1 step, resume and do 2 more
    train_batches = 2
    steps_after_reload = 1 + train_batches
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_steps=steps_after_reload,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[callback],
        track_grad_norm=1,
    )
    # only the init hooks have fired before `fit` is called
    assert called == [
        dict(name="Callback.on_init_start", args=(trainer,)),
        dict(name="Callback.on_init_end", args=(trainer,)),
    ]
    trainer.fit(model, ckpt_path=best_model_path)
    # checkpoint dict expected to be passed to the load hooks
    loaded_ckpt = {
        "callbacks": ANY,
        "epoch": 1,  # TODO: wrong saved epoch, should be 0
        "global_step": 1,
        "lr_schedulers": ANY,
        "optimizer_states": ANY,
        "pytorch-lightning_version": __version__,
        "state_dict": ANY,
        "loops": ANY,
    }
    # TODO: wrong saved epoch, should be 0
    saved_ckpt = {**loaded_ckpt, "global_step": steps_after_reload, "epoch": 2}
    expected = [
        dict(name="Callback.on_init_start", args=(trainer,)),
        dict(name="Callback.on_init_end", args=(trainer,)),
        dict(name="configure_callbacks"),
        dict(name="prepare_data"),
        dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
        dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
        dict(name="setup", kwargs=dict(stage="fit")),
        dict(name="on_load_checkpoint", args=(loaded_ckpt,)),
        dict(name="Callback.on_load_checkpoint", args=(trainer, model, {"foo": True})),
        dict(name="configure_sharded_model"),
        dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
        dict(name="configure_optimizers"),
        dict(name="Callback.on_fit_start", args=(trainer, model)),
        dict(name="on_fit_start"),
        dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
        dict(name="on_pretrain_routine_start"),
        dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
        dict(name="on_pretrain_routine_end"),
        dict(name="train", args=(True,)),
        dict(name="on_train_dataloader"),
        dict(name="train_dataloader"),
        # even though no validation runs, we initialize the val dataloader for properties like `num_val_batches`
        dict(name="on_val_dataloader"),
        dict(name="val_dataloader"),
        dict(name="Callback.on_train_start", args=(trainer, model)),
        dict(name="on_train_start"),
        dict(name="Callback.on_epoch_start", args=(trainer, model)),
        dict(name="on_epoch_start"),
        dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
        dict(name="on_train_epoch_start"),
        # TODO: wrong current epoch after reload
        *model._train_batch(trainer, model, train_batches, current_epoch=1),
        dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
        dict(name="on_save_checkpoint", args=(saved_ckpt,)),
        dict(name="on_train_epoch_end"),
        dict(name="Callback.on_epoch_end", args=(trainer, model)),
        dict(name="on_epoch_end"),
        dict(name="Callback.on_train_end", args=(trainer, model)),
        dict(name="on_train_end"),
        dict(name="Callback.on_fit_end", args=(trainer, model)),
        dict(name="on_fit_end"),
        dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage="fit")),
        dict(name="teardown", kwargs=dict(stage="fit")),
    ]
    assert called == expected
def test_trainer_model_hook_system_fit_no_val_and_resume_max_epochs(tmpdir):
    """Assert the exact hook-call sequence when fitting resumes from a checkpoint and stops on ``max_epochs``.

    Runs one epoch (2 batches, no validation), reloads the resulting
    checkpoint, trains a second epoch, and compares the recorded hook names
    against the full expected sequence.
    """
    # initial training to get a checkpoint
    model = BoringModel()
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=1,
        limit_train_batches=2,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[HookedCallback([])],
    )
    trainer.fit(model)
    best_model_path = trainer.checkpoint_callback.best_model_path

    # hooks fired by the model and callback are appended into this shared list
    called = []
    callback = HookedCallback(called)
    # already performed 1 step, resume and do 2 more
    trainer = Trainer(
        default_root_dir=tmpdir,
        max_epochs=2,
        limit_train_batches=2,
        limit_val_batches=0,
        enable_progress_bar=False,
        enable_model_summary=False,
        callbacks=[callback],
        track_grad_norm=1,
    )
    # only the init hooks have fired before `fit` is called
    assert called == [
        dict(name="Callback.on_init_start", args=(trainer, )),
        dict(name="Callback.on_init_end", args=(trainer, )),
    ]
    # resume from checkpoint with HookedModel
    model = HookedModel(called)
    trainer.fit(model, ckpt_path=best_model_path)
    # checkpoint dict expected to be passed to the load hooks:
    # one epoch of 2 batches was completed, so global_step is 2
    loaded_ckpt = {
        "callbacks": ANY,
        "epoch": 0,
        "global_step": 2,
        "lr_schedulers": ANY,
        "optimizer_states": ANY,
        "pytorch-lightning_version": __version__,
        "state_dict": ANY,
        "loops": ANY,
    }
    # after the resumed second epoch: 2 more batches -> global_step 4, epoch 1
    saved_ckpt = {**loaded_ckpt, "global_step": 4, "epoch": 1}
    expected = [
        dict(name="Callback.on_init_start", args=(trainer, )),
        dict(name="Callback.on_init_end", args=(trainer, )),
        dict(name="configure_callbacks"),
        dict(name="prepare_data"),
        dict(name="Callback.on_before_accelerator_backend_setup", args=(trainer, model)),
        dict(name="Callback.setup", args=(trainer, model), kwargs=dict(stage="fit")),
        dict(name="setup", kwargs=dict(stage="fit")),
        dict(name="on_load_checkpoint", args=(loaded_ckpt, )),
        # {"foo": True} is the state the HookedCallback saved in the checkpoint
        dict(name="Callback.on_load_checkpoint", args=(trainer, model, {
            "foo": True
        })),
        dict(name="Callback.load_state_dict", args=({
            "foo": True
        }, )),
        dict(name="configure_sharded_model"),
        dict(name="Callback.on_configure_sharded_model", args=(trainer, model)),
        dict(name="configure_optimizers"),
        dict(name="Callback.on_fit_start", args=(trainer, model)),
        dict(name="on_fit_start"),
        dict(name="Callback.on_pretrain_routine_start", args=(trainer, model)),
        dict(name="on_pretrain_routine_start"),
        dict(name="Callback.on_pretrain_routine_end", args=(trainer, model)),
        dict(name="on_pretrain_routine_end"),
        dict(name="train", args=(True, )),
        dict(name="train_dataloader"),
        dict(name="Callback.on_train_start", args=(trainer, model)),
        dict(name="on_train_start"),
        dict(name="Callback.on_epoch_start", args=(trainer, model)),
        dict(name="on_epoch_start"),
        dict(name="Callback.on_train_epoch_start", args=(trainer, model)),
        dict(name="on_train_epoch_start"),
        # per-batch hooks for the 2 batches of the resumed epoch
        *model._train_batch(
            trainer, model, 2, current_epoch=1, current_batch=0),
        dict(name="training_epoch_end", args=([dict(loss=ANY)] * 2, )),
        dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
        dict(name="Callback.state_dict"),
        dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
        dict(name="on_save_checkpoint", args=(saved_ckpt, )),
        dict(name="on_train_epoch_end"),
        dict(name="Callback.on_epoch_end", args=(trainer, model)),
        dict(name="on_epoch_end"),
        dict(name="Callback.on_train_end", args=(trainer, model)),
        dict(name="on_train_end"),
        dict(name="Callback.on_fit_end", args=(trainer, model)),
        dict(name="on_fit_end"),
        dict(name="Callback.teardown", args=(trainer, model), kwargs=dict(stage="fit")),
        dict(name="teardown", kwargs=dict(stage="fit")),
    ]
    assert called == expected