import logging
from pathlib import Path
from typing import Optional, Sequence, Union

import torch
from torch.cuda.amp import GradScaler

# Reporter and AbsScheduler are project-local classes; import them from
# the surrounding package.


def resume(
    checkpoint: Union[str, Path],
    model: torch.nn.Module,
    reporter: Reporter,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    scaler: Optional[GradScaler],
    ngpu: int = 0,
):
    # Load the checkpoint onto the current GPU when one is used, else onto CPU.
    states = torch.load(
        checkpoint,
        map_location=f"cuda:{torch.cuda.current_device()}" if ngpu > 0 else "cpu",
    )
    model.load_state_dict(states["model"])
    reporter.load_state_dict(states["reporter"])
    for optimizer, state in zip(optimizers, states["optimizers"]):
        optimizer.load_state_dict(state)
    for scheduler, state in zip(schedulers, states["schedulers"]):
        # Entries may be None when no scheduler is attached to an optimizer.
        if scheduler is not None:
            scheduler.load_state_dict(state)
    if scaler is not None:
        if states["scaler"] is None:
            logging.warning("scaler state is not found")
        else:
            scaler.load_state_dict(states["scaler"])
    logging.info(f"The training was resumed using {checkpoint}")
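For context, a minimal sketch of the checkpoint layout that resume() expects, inferred from the keys it reads back; the save_checkpoint name and signature are assumptions, not part of the original code.

def save_checkpoint(
    path: Union[str, Path],
    model: torch.nn.Module,
    reporter: Reporter,
    optimizers: Sequence[torch.optim.Optimizer],
    schedulers: Sequence[Optional[AbsScheduler]],
    scaler: Optional[GradScaler],
):
    # Hypothetical counterpart to resume(): the dict keys mirror exactly
    # what resume() reads ("model", "reporter", "optimizers", "schedulers",
    # "scaler"); absent schedulers and scaler are stored as None.
    torch.save(
        {
            "model": model.state_dict(),
            "reporter": reporter.state_dict(),
            "optimizers": [o.state_dict() for o in optimizers],
            "schedulers": [
                s.state_dict() if s is not None else None for s in schedulers
            ],
            "scaler": scaler.state_dict() if scaler is not None else None,
        },
        path,
    )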
def test_state_dict():
    reporter = Reporter()
    reporter.set_epoch(1)
    with reporter.observe("train") as sub:
        stats1 = {"aa": 0.6}
        sub.register(stats1)
    with reporter.observe("eval") as sub:
        stats1 = {"bb": 0.6}
        sub.register(stats1)
    state = reporter.state_dict()

    # Round trip: a fresh Reporter loaded from the saved state must
    # serialize back to an identical state dict.
    reporter2 = Reporter()
    reporter2.load_state_dict(state)
    state2 = reporter2.state_dict()

    assert state == state2
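The test above is a pure in-memory round trip: serialize, load into a fresh Reporter, serialize again, and assert the two states are equal, which catches fields dropped or renamed by load_state_dict(). A variant that pushes the same state through an actual file mirrors how resume() restores the reporter from disk; this sketch assumes Reporter's state_dict() is picklable (which resume() already relies on via torch.load) and uses pytest's built-in tmp_path fixture. The test name is an assumption.

def test_state_dict_torch_save(tmp_path):
    reporter = Reporter()
    reporter.set_epoch(1)
    with reporter.observe("train") as sub:
        sub.register({"aa": 0.6})

    # Persist the state exactly as a checkpoint would, then reload it.
    path = tmp_path / "checkpoint.pth"
    torch.save({"reporter": reporter.state_dict()}, path)

    reporter2 = Reporter()
    reporter2.load_state_dict(torch.load(path)["reporter"])
    assert reporter.state_dict() == reporter2.state_dict()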