import torch
from torch import nn, optim
from torch.optim import lr_scheduler as sche  # imports implied by the annotations below


def save_checkpoint(
    model: nn.Module = None,
    optimizer: optim.Optimizer = None,
    scheduler: sche._LRScheduler = None,
    amp=None,
    exp_name: str = "",
    current_epoch: int = 1,
    full_net_path: str = "",
    state_net_path: str = "",
):
    """
    Save both the full checkpoint (large: model, optimizer, scheduler and amp state)
    and the model-only state dict (small).

    Args:
        model (nn.Module): model object
        optimizer (optim.Optimizer): optimizer object
        scheduler (sche._LRScheduler): scheduler object
        amp: the ``apex.amp`` module, if mixed-precision training is used
        exp_name (str): experiment name (stored as the checkpoint's "arch" field)
        current_epoch (int): the epoch in which the model will be trained next,
            i.e. the epoch to resume from
        full_net_path (str): the path for saving the full checkpoint
        state_net_path (str): the path for saving the model state dict only
    """
    state_dict = {
        "arch": exp_name,
        "epoch": current_epoch,
        "net_state": model.state_dict(),
        "opti_state": optimizer.state_dict(),
        "sche_state": scheduler.state_dict(),
        "amp_state": amp.state_dict() if amp else None,
    }
    torch.save(state_dict, full_net_path)
    torch.save(model.state_dict(), state_net_path)
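# A minimal resume-side sketch for the checkpoint layout written above. The
# function name `resume_checkpoint` and its return value are hypothetical,
# not part of the original code; it only mirrors the keys used in
# `save_checkpoint` ("epoch", "net_state", "opti_state", "sche_state", "amp_state").
def resume_checkpoint(
    model: nn.Module,
    optimizer: optim.Optimizer,
    scheduler: sche._LRScheduler,
    amp=None,
    full_net_path: str = "",
) -> int:
    # Load onto CPU first so the checkpoint does not require the original device layout.
    checkpoint = torch.load(full_net_path, map_location="cpu")
    model.load_state_dict(checkpoint["net_state"])
    optimizer.load_state_dict(checkpoint["opti_state"])
    scheduler.load_state_dict(checkpoint["sche_state"])
    if amp is not None and checkpoint.get("amp_state") is not None:
        amp.load_state_dict(checkpoint["amp_state"])
    # Return the epoch at which training should resume.
    return checkpoint["epoch"]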
def snapshot(self, net: torch.nn.Module, opt: Optimizer, sched: _LRScheduler = None,
             epoch: int = None, subdir='.'):
    """
    Writes a snapshot of the training, i.e. network weights, optimizer state and
    scheduler state, to a file in the log directory.

    :param net: the neural network
    :param opt: the optimizer used
    :param sched: the learning rate scheduler used
    :param epoch: the current epoch
    :param subdir: if given, creates a subdirectory in the log directory. The data
        is written to a file in this subdirectory instead.
    :return: the path of the written snapshot file
    """
    outfile = pt.join(self.dir, subdir, 'snapshot.pt')
    if not pt.exists(os.path.dirname(outfile)):
        os.makedirs(os.path.dirname(outfile))
    torch.save(
        {
            'net': net.state_dict(),
            'opt': opt.state_dict(),
            # The scheduler is optional, so only save its state when one is given.
            'sched': sched.state_dict() if sched is not None else None,
            'epoch': epoch
        }, outfile)
    return outfile
def save(self, path_to_checkpoints_dir: str, step: int, optimizer: Optimizer,
         scheduler: _LRScheduler) -> str:
    path_to_checkpoint = os.path.join(path_to_checkpoints_dir, f'model-{step}.pth')
    checkpoint = {
        'state_dict': self.state_dict(),
        'step': step,
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict()
    }
    torch.save(checkpoint, path_to_checkpoint)
    return path_to_checkpoint
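# A hedged companion sketch: locate the newest `model-{step}.pth` written by
# `save` above and restore it. The helper name `load_latest_checkpoint` and the
# assumption that the model exposes `load_state_dict` (implied by the
# `self.state_dict()` call above) are mine; only the file naming and checkpoint
# keys come from the original snippet.
import glob
import os
import re

import torch


def load_latest_checkpoint(model, optimizer, scheduler, path_to_checkpoints_dir: str) -> int:
    paths = glob.glob(os.path.join(path_to_checkpoints_dir, 'model-*.pth'))
    if not paths:
        return 0
    # Sort numerically by the step encoded in the filename, not lexicographically.
    latest = max(paths, key=lambda p: int(re.search(r'model-(\d+)\.pth$', p).group(1)))
    checkpoint = torch.load(latest, map_location='cpu')
    model.load_state_dict(checkpoint['state_dict'])
    optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
    scheduler.load_state_dict(checkpoint['scheduler_state_dict'])
    return checkpoint['step']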
def _better_lr_sched_repr(lr_sched: _LRScheduler) -> str:
    return (
        lr_sched.__class__.__name__
        + "(\n "
        + "\n ".join(
            f"{k}: {v}"
            for k, v in lr_sched.state_dict().items()
            if not k.startswith("_")
        )
        + "\n)"
    )
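# Illustrative use of `_better_lr_sched_repr`: the model, optimizer, and
# scheduler below are throwaway objects chosen just to show the kind of
# multi-line repr the helper produces (one `key: value` line per public entry
# of the scheduler's state dict).
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR

_model = nn.Linear(4, 2)
_opt = optim.SGD(_model.parameters(), lr=0.1)
_sched = StepLR(_opt, step_size=10, gamma=0.5)
print(_better_lr_sched_repr(_sched))
# e.g.
# StepLR(
#  step_size: 10
#  gamma: 0.5
#  ...
# )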
def simulate_values(  # type: ignore[override]
    cls, num_events: int, lr_scheduler: _LRScheduler, **kwargs: Any
) -> List[List[int]]:
    """Method to simulate scheduled values during num_events events.

    Args:
        num_events (int): number of events during the simulation.
        lr_scheduler (subclass of `torch.optim.lr_scheduler._LRScheduler`): lr_scheduler object to wrap.

    Returns:
        list of pairs: [event_index, value]
    """
    if not isinstance(lr_scheduler, _LRScheduler):
        raise TypeError(
            "Argument lr_scheduler should be a subclass of torch.optim.lr_scheduler._LRScheduler, "
            f"but given {type(lr_scheduler)}")

    # This scheduler uses `torch.optim.lr_scheduler._LRScheduler` which
    # should be replicated in order to simulate LR values and
    # not perturb original scheduler.
    with tempfile.TemporaryDirectory() as tmpdirname:
        cache_filepath = Path(tmpdirname) / "ignite_lr_scheduler_cache.pt"
        obj = {
            "lr_scheduler": lr_scheduler.state_dict(),
            "optimizer": lr_scheduler.optimizer.state_dict(),  # type: ignore[attr-defined]
        }
        torch.save(obj, cache_filepath.as_posix())

        values = []
        scheduler = cls(save_history=False, lr_scheduler=lr_scheduler, **kwargs)  # type: ignore[call-arg]
        for i in range(num_events):
            params = [p[scheduler.param_name] for p in scheduler.optimizer_param_groups]
            values.append([i] + params)
            scheduler(engine=None)

        obj = torch.load(cache_filepath.as_posix())
        lr_scheduler.load_state_dict(obj["lr_scheduler"])
        lr_scheduler.optimizer.load_state_dict(obj["optimizer"])  # type: ignore[attr-defined]

        return values
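# Usage sketch for the classmethod above, assuming the enclosing class is
# ignite's `LRScheduler` handler (the `cls` being instantiated); older ignite
# releases expose it under `ignite.contrib.handlers` instead. The toy model,
# optimizer, and StepLR below are only for illustration. The call returns
# [event_index, lr] pairs without disturbing the wrapped scheduler's state.
import torch
from torch import nn, optim
from torch.optim.lr_scheduler import StepLR

from ignite.handlers import LRScheduler

model = nn.Linear(4, 2)
optimizer = optim.SGD(model.parameters(), lr=0.1)
torch_scheduler = StepLR(optimizer, step_size=3, gamma=0.1)

# Simulate 10 events' worth of learning rates, e.g. [[0, 0.1], [1, 0.1], ...].
values = LRScheduler.simulate_values(num_events=10, lr_scheduler=torch_scheduler)
print(values)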
def collect_state_dict(
    self,
    iteration: Union[float, int],
    model: EmmentalModel,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    metric_dict: Dict[str, float],
) -> Dict[str, Any]:
    r"""Collect the state dict of the model.

    Args:
        iteration(float or int): The current iteration.
        model(EmmentalModel): The model to checkpoint.
        optimizer(Optimizer): The optimizer used during training process.
        lr_scheduler(_LRScheduler): Learning rate scheduler.
        metric_dict(dict): The metric dict.

    Returns:
        dict: The state dict.
    """
    model_params = {
        "name": model.name,
        "module_pool": model.collect_state_dict(),
        # "task_names": model.task_names,
        # "task_flows": model.task_flows,
        # "loss_funcs": model.loss_funcs,
        # "output_funcs": model.output_funcs,
        # "scorers": model.scorers,
    }

    state_dict = {
        "iteration": iteration,
        "model": model_params,
        "optimizer": optimizer.state_dict(),
        "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None,
        "metric_dict": metric_dict,
    }

    return state_dict
def fit_support(
    self,
    model,
    tasks: List[Task],
    dataloader: DataLoader,
    optimizer: Optimizer,
    scheduler: _LRScheduler,
    training_logger: ResultLogger,
):
    support_loss = 1.0
    support_epoch = 0

    # Snapshot the optimizer and scheduler states so the support phase does not change them
    optimizer_state_dict = deepcopy(optimizer.state_dict())
    scheduler_state_dict = deepcopy(scheduler.state_dict())

    # Reset task states
    for task in tasks:
        task.reset()

    model.freeze_weights()

    while (support_loss > self.support_min_loss
           and support_epoch < self.support_max_epochs):
        support_epoch += 1
        support_loss = self.fit_one(
            model,
            tasks,
            dataloader,
            optimizer,
            scheduler,
            training_logger.epoch(support_epoch, self.support_max_epochs),
            train_model=False,
        )

    # Restore the original optimizer and scheduler states and unfreeze the model
    optimizer.load_state_dict(optimizer_state_dict)
    scheduler.load_state_dict(scheduler_state_dict)
    model.defreeze_weights()
def checkpoint(
    self,
    iteration: Union[float, int],
    model: EmmentalModel,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler,
    metric_dict: Dict[str, float],
) -> None:
    """Save a checkpoint at the given iteration.

    Args:
        iteration: The current iteration.
        model: The model to checkpoint.
        optimizer: The optimizer used during training process.
        lr_scheduler: Learning rate scheduler.
        metric_dict: The metric dict.
    """
    # Check the checkpoint_runway condition is met
    if iteration < self.checkpoint_runway:
        return
    elif not self.checkpoint_condition_met and iteration >= self.checkpoint_runway:
        self.checkpoint_condition_met = True
        logger.info(
            "checkpoint_runway condition has been met. Start checkpointing."
        )

    # Save model state
    model_path = f"{self.checkpoint_path}/checkpoint_{iteration}.model.pth"
    model.save(model_path, verbose=False)
    logger.info(f"Save checkpoint of {iteration} {self.checkpoint_unit} "
                f"at {model_path}.")

    # Save optimizer state
    optimizer_path = f"{self.checkpoint_path}/checkpoint_{iteration}.optimizer.pth"
    optimizer_dict = {
        "optimizer": optimizer.state_dict(),
    }
    torch.save(optimizer_dict, optimizer_path)

    # Save lr_scheduler state
    scheduler_path = f"{self.checkpoint_path}/checkpoint_{iteration}.scheduler.pth"
    scheduler_dict = {
        "lr_scheduler": lr_scheduler.state_dict() if lr_scheduler else None
    }
    torch.save(scheduler_dict, scheduler_path)

    # Keep only the latest checkpoint files unless checkpoint_all is enabled
    if self.checkpoint_all is False:
        for path in self.checkpoint_paths:
            if os.path.exists(path):
                os.remove(path)

    self.checkpoint_paths.extend([model_path, optimizer_path, scheduler_path])

    # Copy the files as "best" checkpoints when any tracked metric improves
    if not set(self.checkpoint_all_metrics.keys()).isdisjoint(set(metric_dict.keys())):
        new_best_metrics = self.is_new_best(metric_dict)
        for metric in new_best_metrics:
            best_metric_model_path = (
                f"{self.checkpoint_path}/best_model_"
                f"{metric.replace('/', '_')}.model.pth")
            copyfile(model_path, best_metric_model_path)

            logger.info(
                f"Save best model of metric {metric} to {best_metric_model_path}"
            )

            best_metric_optimizer_path = (
                f"{self.checkpoint_path}/best_model_"
                f"{metric.replace('/', '_')}.optimizer.pth")
            copyfile(optimizer_path, best_metric_optimizer_path)

            best_metric_scheduler_path = (
                f"{self.checkpoint_path}/best_model_"
                f"{metric.replace('/', '_')}.scheduler.pth")
            copyfile(scheduler_path, best_metric_scheduler_path)
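# A hedged restore-side sketch for the three files written per iteration above.
# It assumes EmmentalModel exposes a `load(path)` counterpart to the
# `save(path)` call used in `checkpoint`; if it does not, substitute the
# project's own model-loading routine. The dict keys ("optimizer",
# "lr_scheduler") mirror the ones saved above; the helper name
# `restore_checkpoint` is hypothetical.
import torch


def restore_checkpoint(checkpoint_path, iteration, model, optimizer, lr_scheduler=None):
    prefix = f"{checkpoint_path}/checkpoint_{iteration}"

    # Model weights were saved with model.save(...); a symmetric load is assumed here.
    model.load(f"{prefix}.model.pth")

    optimizer_dict = torch.load(f"{prefix}.optimizer.pth", map_location="cpu")
    optimizer.load_state_dict(optimizer_dict["optimizer"])

    scheduler_dict = torch.load(f"{prefix}.scheduler.pth", map_location="cpu")
    if lr_scheduler is not None and scheduler_dict["lr_scheduler"] is not None:
        lr_scheduler.load_state_dict(scheduler_dict["lr_scheduler"])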