def save_checkpoint(model, filename, meta=None):
    """Save checkpoint to file.

    The checkpoint has two fields: ``meta`` and ``state_dict``.
    ``meta`` is stored exactly as supplied (or ``{}`` when ``None``);
    no version/time information is added here.

    Args:
        model (Module): Module whose params are to be saved.
        filename (str): Checkpoint filename; parent directories are
            created if missing.
        meta (dict, optional): Metadata to be saved in checkpoint.

    Raises:
        TypeError: If ``meta`` is neither a dict nor None.
    """
    if meta is None:
        meta = {}
    elif not isinstance(meta, dict):
        raise TypeError("meta must be a dict or None, but got {}".format(
            type(meta)))

    torchie.mkdir_or_exist(osp.dirname(filename))
    # Unwrap a (Distributed)DataParallel-style wrapper so the saved
    # state_dict keys carry no "module." prefix.
    if hasattr(model, "module"):
        model = model.module

    checkpoint = {
        "meta": meta,
        "state_dict": weights_to_cpu(model.state_dict()),
    }
    torch.save(checkpoint, filename)
def __init__(
    self,
    model,
    batch_processor,
    optimizer=None,
    lr_scheduler=None,
    work_dir=None,
    log_level=logging.INFO,
    logger=None,
    **kwargs,
):
    """Initialize trainer state.

    Args:
        model (Module): Model to train/evaluate.
        batch_processor (callable): Function that processes one batch.
        optimizer (:obj:`Optimizer`, optional): Optimizer to use.
        lr_scheduler (optional): Learning-rate scheduler.
        work_dir (str, optional): Working directory for logs and
            checkpoints; created if it does not already exist.
        log_level (int): Level for the default logger.
        logger (:obj:`logging.Logger`, optional): Existing logger to
            reuse instead of creating a new one.
        **kwargs: Accepted for compatibility and ignored.
    """
    assert callable(batch_processor)
    self.model = model
    self.optimizer = optimizer
    self.lr_scheduler = lr_scheduler
    self.batch_processor = batch_processor

    # Resolve and create the working directory.
    if work_dir is None:
        self.work_dir = None
    elif torchie.is_str(work_dir):
        self.work_dir = osp.abspath(work_dir)
        torchie.mkdir_or_exist(self.work_dir)
    else:
        raise TypeError("'work_dir' must be a str or None")

    # Record the class name of the (possibly wrapped) model.
    net = self.model.module if hasattr(self.model, "module") else self.model
    self._model_name = net.__class__.__name__

    self._rank, self._world_size = get_dist_info()
    self.timestamp = get_time_str()
    self.logger = (
        self.init_logger(work_dir, log_level) if logger is None else logger
    )
    self.log_buffer = LogBuffer()

    # Bookkeeping counters for the training loop.
    self.mode = None
    self._hooks = []
    self._epoch = 0
    self._iter = 0
    self._inner_iter = 0
    self._max_epochs = 0
    self._max_iters = 0
    self._example_stats = None