# NOTE(review): a later definition in this file reuses this exact test name,
# so pytest will only collect the last one — this version is shadowed. Confirm
# whether this earlier copy should be removed or renamed.
def test_loops_state_dict():
    """A FitLoop's state dict round-trips into a freshly constructed loop."""
    trainer = Trainer()
    source_loop = FitLoop()
    source_loop.trainer = trainer

    # Snapshot the source loop's state and restore it into a new loop
    # attached to the same trainer.
    saved_state = source_loop.state_dict()
    restored_loop = FitLoop()
    restored_loop.trainer = trainer
    restored_loop.load_state_dict(saved_state)

    # Both loops must now serialize to identical state.
    assert source_loop.state_dict() == restored_loop.state_dict()
def test_loops_state_dict():
    """FitLoop state round-trips, and attaching a non-Trainer object raises."""
    trainer = Trainer()
    source_loop = FitLoop()

    # Assigning anything other than a Trainer must be rejected up front.
    with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"):
        source_loop.trainer = object()

    source_loop.trainer = trainer

    # Snapshot the source loop's state and restore it into a new loop
    # attached to the same trainer.
    saved_state = source_loop.state_dict()
    restored_loop = FitLoop()
    restored_loop.trainer = trainer
    restored_loop.load_state_dict(saved_state)

    # Both loops must now serialize to identical state.
    assert source_loop.state_dict() == restored_loop.state_dict()
def test_fit_loop_done_log_messages(caplog):
    """Each stopping condition of ``FitLoop.done`` emits its own log message.

    Walks through the stopping conditions in order — no training batches,
    ``max_steps`` reached, ``max_epochs`` reached, ``should_stop`` honored /
    deferred — asserting both the ``done`` flag and the logged text after each
    state mutation. Statement order matters: every assertion depends on the
    mutations immediately preceding it, and ``caplog.clear()`` separates the
    scenarios.
    """
    fit_loop = FitLoop()
    # Mock trainer constrained to the Trainer API so attribute typos fail fast.
    trainer = Mock(spec=Trainer)
    fit_loop.trainer = trainer

    # Baseline: batches available, no stop requested -> not done, nothing logged.
    trainer.should_stop = False
    trainer.num_training_batches = 5
    assert not fit_loop.done
    assert not caplog.messages

    # Scenario 1: zero training batches -> done immediately, with a message.
    trainer.num_training_batches = 0
    assert fit_loop.done
    assert "No training batches" in caplog.text
    caplog.clear()

    # Scenario 2: global_step has hit max_steps -> done.
    trainer.num_training_batches = 5
    epoch_loop = Mock()
    epoch_loop.global_step = 10
    fit_loop.connect(epoch_loop=epoch_loop)
    fit_loop.max_steps = 10
    assert fit_loop.done
    assert "max_steps=10` reached" in caplog.text
    caplog.clear()

    # Scenario 3: processed epochs have hit max_epochs -> done.
    fit_loop.max_steps = 20
    fit_loop.epoch_progress.current.processed = 3
    fit_loop.max_epochs = 3
    trainer.should_stop = True
    assert fit_loop.done
    assert "max_epochs=3` reached" in caplog.text
    caplog.clear()

    # Scenario 4: should_stop=True with min_steps satisfied -> done; these
    # messages are logged at DEBUG level, so capture must be widened.
    fit_loop.max_epochs = 5
    fit_loop.epoch_loop.min_steps = 0
    # NOTE(review): the flattened source makes the extent of this `with` block
    # ambiguous; the trailing min_steps=100 assertions are assumed to also need
    # DEBUG capture and are kept inside it — confirm against upstream.
    with caplog.at_level(level=logging.DEBUG, logger="pytorch_lightning.utilities.rank_zero"):
        assert fit_loop.done
        assert "should_stop` was set" in caplog.text

        # Scenario 5: should_stop requested but min_steps not yet met -> the
        # stop signal is deferred and the loop is not done.
        fit_loop.epoch_loop.min_steps = 100
        assert not fit_loop.done
        assert "was signaled to stop but" in caplog.text