def test_loops_state_dict(): trainer = Trainer() trainer.train_dataloader = Mock() fit_loop = FitLoop() with pytest.raises(MisconfigurationException, match="Loop FitLoop should be connected to a"): fit_loop.trainer = object() fit_loop.trainer = trainer fit_loop.connect(Mock()) state_dict = fit_loop.state_dict() new_fit_loop = FitLoop() new_fit_loop.trainer = trainer new_fit_loop.load_state_dict(state_dict) assert fit_loop.state_dict() == new_fit_loop.state_dict()
def test_loops_state_dict_structure(): trainer = Trainer() trainer.train_dataloader = Mock() state_dict = trainer._checkpoint_connector._get_loops_state_dict() expected = { "fit_loop": { "state_dict": {}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, }, "epoch_loop.scheduler_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.manual_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer": { "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "zero_grad": { "total": {"ready": 0, "started": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "completed": 0}, }, }, "optimizer_position": 0, }, "epoch_loop.val_loop.state_dict": {}, "epoch_loop.val_loop.dataloader_progress": { "total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, "epoch_loop.val_loop.epoch_loop.batch_progress": { # number of batches across validation runs per epoch "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, # number of batches for this validation run "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, }, "epoch_loop.val_loop._results": { "batch": None, "batch_size": None, "dataloader_idx": None, "training": False, "device": None, "items": {}, }, "epoch_loop._results": { "batch": None, "batch_size": None, "dataloader_idx": None, "training": True, "device": None, "items": {}, }, "epoch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, "validate_loop": { "state_dict": {}, "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { # total batches run by `validate` "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, # number of batches run by this `validate` call "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, }, "_results": { "batch": None, "batch_size": None, "dataloader_idx": None, "training": False, "device": None, "items": {}, }, }, "test_loop": { "state_dict": {}, "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "is_last_batch": False, }, "_results": { "batch": None, "batch_size": None, "dataloader_idx": None, "training": False, "device": None, "items": {}, }, }, "predict_loop": { "state_dict": {}, "dataloader_progress": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, "current": {"ready": 0, "started": 0, "processed": 0, "completed": 0}, }, }, } assert state_dict == expected
def test_loops_state_dict_structure(): trainer = Trainer() trainer.train_dataloader = Mock() state_dict = trainer.checkpoint_connector._get_loops_state_dict() expected = { "fit_loop": { "state_dict": {}, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "is_last_batch": False, }, "epoch_loop.scheduler_progress": { "total": { "ready": 0, "completed": 0 }, "current": { "ready": 0, "completed": 0 }, }, "epoch_loop.batch_loop.state_dict": {}, "epoch_loop.batch_loop.manual_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.state_dict": {}, "epoch_loop.batch_loop.optimizer_loop.optim_progress": { "optimizer": { "step": { "total": { "ready": 0, "completed": 0 }, "current": { "ready": 0, "completed": 0 } }, "zero_grad": { "total": { "ready": 0, "started": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "completed": 0 }, }, }, "optimizer_position": 0, }, "epoch_loop.val_loop.state_dict": {}, "epoch_loop.val_loop.dataloader_progress": { "total": { "ready": 0, "completed": 0 }, "current": { "ready": 0, "completed": 0 }, }, "epoch_loop.val_loop.epoch_loop.state_dict": {}, "epoch_loop.val_loop.epoch_loop.batch_progress": { "total": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "is_last_batch": False, }, "epoch_loop.val_loop._results": { "training": False, "_batch_size": torch.tensor(1), "device": None, "items": {}, }, "epoch_loop._results": { "training": True, "_batch_size": torch.tensor(1), "device": None, "items": {}, }, "epoch_progress": { "total": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, }, }, "validate_loop": { "state_dict": {}, "dataloader_progress": { "total": { "ready": 0, "completed": 0 }, "current": { "ready": 0, "completed": 0 } }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "is_last_batch": False, }, "_results": { "training": False, "_batch_size": torch.tensor(1), "device": None, "items": {}, }, }, "test_loop": { "state_dict": {}, "dataloader_progress": { "total": { "ready": 0, "completed": 0 }, "current": { "ready": 0, "completed": 0 } }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "is_last_batch": False, }, "_results": { "training": False, "_batch_size": torch.tensor(1), "device": None, "items": {}, }, }, "predict_loop": { "state_dict": {}, "dataloader_progress": { "total": { "ready": 0, "completed": 0 }, "current": { "ready": 0, "completed": 0 } }, "epoch_loop.state_dict": {}, "epoch_loop.batch_progress": { "total": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, "current": { "ready": 0, "started": 0, "processed": 0, "completed": 0 }, }, }, } assert state_dict == expected