def load(self, name=None):
    """Load a checkpoint and return its deserialized content.

    Args:
        name: Identifier of the checkpoint to load; converted with ``str``
            unless it is ``None``. ``None``, or an existing directory whose
            relative path equals the checkpoint directory, selects the last
            entry of the checkpoint logs. Any other value is first resolved
            with ``_get_ckpt_file`` and, if that file is not logged, is
            treated as a path itself.

    Returns:
        Whatever ``io_registry_group.load`` returns for the resolved file.

    Raises:
        FileNotFoundError: If neither resolution of ``name`` is listed in
            the checkpoint logs.
    """
    name = name if name is None else str(name)
    self._before('load')
    self._before('load:parse_logs')
    # NOTE(review): the lock is assumed to guard only the log read, so that
    # every before/finish hook pair stays properly nested -- confirm.
    with filelock.FileLock(self.ckpt_logs_lock_file):
        self._before('load:read_logs')
        self._ckpt_logs = self._get_ckpt_logs()
        # Close the hook opened above; it used to be opened a second time
        # and never finished.
        self._finish('load:read_logs')

    logged = self._ckpt_logs['checkpoints']
    if name is None or (os.path.isdir(name)
                        and get_relative_path(name) == self._ckpt_path):
        # No specific checkpoint requested: take the most recent log entry.
        # An empty log makes this index lookup fail.
        self._ckpt_file = os.path.join(self._ckpt_path, logged[-1])
    else:
        # First interpret ``name`` as a checkpoint name ...
        self._ckpt_file = self._get_ckpt_file(name)
        if get_relative_path(self._ckpt_file,
                             self._ckpt_path) not in logged:
            # ... then fall back to interpreting it as a file path.
            self._ckpt_file = get_relative_path(name)
            if get_relative_path(self._ckpt_file,
                                 self._ckpt_path) not in logged:
                raise FileNotFoundError(
                    f'The checkpoint of "{name}" is not found')
    self._finish('load:parse_logs')

    self._before('load:read_data')
    result = io_registry_group.load(self._ckpt_file)
    self._finish('load:read_data')
    self._finish('load')
    return result
def save(self, name, data):
    """Write ``data`` as the checkpoint ``name`` and record it in the logs.

    Nothing is written when ``name`` equals the name of the immediately
    preceding save. Otherwise, while holding the logs lock, the oldest
    logged checkpoints are removed until fewer than ``_num_ckpts`` remain
    (no pruning when ``_num_ckpts`` is not positive), the data is dumped
    and the logs are rewritten.

    Args:
        name: Checkpoint identifier; converted with ``str``.
        data: Object handed to ``io_registry_group.dump``.
    """
    self._before('save')
    self._current_name = str(name)
    if self._latest_saved_name != self._current_name:
        # The whole read-modify-write cycle on the logs happens under the
        # lock.
        with filelock.FileLock(self.ckpt_logs_lock_file):
            self._before('save:read_logs')
            self._ckpt_logs = self._get_ckpt_logs()
            self._finish('save:read_logs')

            # Drop the oldest entries first to respect the retention limit.
            while len(
                    self._ckpt_logs['checkpoints']) >= self._num_ckpts > 0:
                self._before('save:remove_ckpt')
                self._erased_ckpt_file = os.path.join(
                    self._ckpt_path, self._ckpt_logs['checkpoints'].pop(0))
                remove(self._erased_ckpt_file)
                self._finish('save:remove_ckpt')

            self._before('save:write_data')
            self._ckpt_file = self._get_ckpt_file(self._current_name)
            io_registry_group.dump(self._ckpt_file, data)
            self._finish('save:write_data')

            self._before('save:write_logs')
            saved_entry = get_relative_path(self._ckpt_file,
                                            self._ckpt_path)
            logged = self._ckpt_logs['checkpoints']
            # Re-saving an already logged checkpoint must not leave a
            # duplicate entry: a later prune would pop the older duplicate
            # and delete a file that the logs still reference.
            logged[:] = [entry for entry in logged if entry != saved_entry]
            logged.append(saved_entry)
            io_registry_group.dump(self.ckpt_logs_file, self._ckpt_logs)
            self._finish('save:write_logs')
    self._finish('save')
    self._latest_saved_name = self._current_name
def save(self, data, *, global_step, logger: logging.Logger):
    """Save ``data`` as ``model_<global_step>.pth`` in the current session.

    While holding the ``checkpoints.json.lock`` file lock of the session
    directory, the oldest logged checkpoints are removed until fewer than
    ``_num_sessions`` remain (no pruning when ``_num_sessions`` is not
    positive), ``data`` is written with ``torch.save`` and the checkpoint
    logs are stored again.

    Args:
        data: Object passed to ``torch.save``.
        global_step: Step number embedded in the checkpoint file name.
        logger: Logger that receives the progress messages.
    """
    session_path = self.current_session_path
    checkpoint = os.path.join(session_path, f"model_{global_step}.pth")
    logger.info(
        f"==> Saving session of global step {global_step} to {checkpoint}..."
    )
    # Plain string: the lock file name has no placeholders.
    lock_file = os.path.join(session_path, "checkpoints.json.lock")
    with filelock.FileLock(lock_file):
        checkpoint_logs = self._load_checkpoint_logs(session_path)
        logged = checkpoint_logs['checkpoints']

        # Drop the oldest entries first to respect the retention limit.
        while len(logged) >= self._num_sessions > 0:
            removed_checkpoint = logged.pop(0)
            remove(os.path.join(session_path, removed_checkpoint))
            logger.debug(f"Removed checkpoint {removed_checkpoint}")

        torch.save(data, checkpoint)

        saved_entry = get_relative_path(checkpoint, session_path)
        # Saving the same global step twice must not leave a duplicate
        # entry: a later prune would pop the older duplicate and delete a
        # file that the logs still reference.
        logged[:] = [entry for entry in logged if entry != saved_entry]
        logged.append(saved_entry)
        self._save_checkpoint_logs(session_path, checkpoint_logs)
def get_parsed_path(cls, path):
    """Split ``path`` into the components returned by ``get_path_dirs``.

    An existing file has the extension of its last component stripped
    before the path is made relative and split.

    Args:
        path: Path that ``is_subdirectory`` must accept.

    Returns:
        The result of ``get_path_dirs`` for the relative, extension-free
        path.

    Raises:
        ValueError: If ``is_subdirectory(path)`` is false.
    """
    if not is_subdirectory(path):
        raise ValueError("Only path in current directory can be parsed")

    absolute = get_absolute_path(path)
    if os.path.isfile(absolute):
        # Keep the directory part untouched; only the file name loses its
        # extension.
        parent, basename = os.path.split(absolute)
        stem, _extension = os.path.splitext(basename)
        absolute = os.path.join(parent, stem)

    return get_path_dirs(get_relative_path(absolute))
def get_default_logger(name=None,
                       logger_level="NOTSET",
                       group_logger_key: str = None,
                       group_handler_key='__default_console__',
                       group: LogGroup = global_group,
                       subscribe=True,
                       format_kwargs=None,
                       formatter_class: Type[logging.Formatter] = ColoredFormatter,
                       formatter_kwargs=None,
                       handler_level="NOTSET",
                       handler_class=logging.StreamHandler,
                       handler_kwargs=None,
                       log_format=DefaultLogFormat) -> logging.Logger:
    """Return a logger registered in ``group`` and tied to a shared handler.

    When ``group`` has no handler under ``group_handler_key`` yet, one is
    built with ``make_handler`` and bound to that key. The logger key and
    handler key are then associated through ``group.add_loggers_handlers``
    and, if ``subscribe`` is true, ``group.subscribe_logger_handler``.
    All group access happens while holding ``group.lock``.

    Args:
        name: Logger name; defaults to the relative path of the file of the
            module returned by ``get_caller_module()``.
        logger_level: Level forwarded to ``get_logger``.
        group_logger_key: Key of the logger inside ``group``; falls back to
            ``name`` when empty.
        group_handler_key: Key of the handler inside ``group``.
        group: Log group that owns the loggers and handlers.
        subscribe: Whether to also call ``group.subscribe_logger_handler``.
        format_kwargs, formatter_class, formatter_kwargs, handler_level,
        handler_class, handler_kwargs, log_format: Forwarded to
            ``make_handler`` when a new handler has to be created. If
            ``handler_kwargs`` is empty, stream handler classes receive
            ``{'stream': sys.stdout}`` and other classes an empty dict.

    Returns:
        The logger obtained from ``get_logger``.
    """
    if name is None:
        # Kept in this frame: the helper presumably inspects the call stack
        # to find the caller -- do not move into a nested function.
        name = get_relative_path(get_caller_module().__file__)
    if not group_logger_key:
        group_logger_key = name

    with group.lock:
        logger = get_logger(name, logger_level, group_logger_key, group)

        if not group.lookup_handlers(group_handler_key):
            # Default target for stream-based handlers is stdout.
            fallback_kwargs = (
                {'stream': sys.stdout}
                if issubclass(handler_class, logging.StreamHandler)
                else {}
            )
            new_handler = make_handler(
                handler_class,
                level=handler_level,
                format_kwargs=format_kwargs,
                formatter_class=formatter_class,
                formatter_kwargs=formatter_kwargs,
                handler_kwargs=handler_kwargs or fallback_kwargs,
                log_format=log_format,
            )
            group.bind_handler(group_handler_key, new_handler)

        group.add_loggers_handlers(group_logger_key, group_handler_key)
        if subscribe:
            group.subscribe_logger_handler(group_logger_key,
                                           group_handler_key)
    return logger
def __init__(self, ckpt_path, num_ckpts=5):
    """Register the save/load hook keys and initialise the bookkeeping.

    Args:
        ckpt_path: Checkpoint directory; stored after conversion with
            ``get_relative_path``.
        num_ckpts: Value stored as ``_num_ckpts`` (default ``5``).
    """
    super().__init__()

    # Registration order is kept: save phases first, then load phases.
    for hook_key in (
            'save',
            'save:read_logs',
            'save:remove_ckpt',
            'save:write_data',
            'save:write_logs',
            'load',
            'load:parse_logs',
            'load:read_logs',
            'load:read_data',
    ):
        self._register_hook_key(hook_key)

    self._ckpt_path = get_relative_path(ckpt_path)
    # State filled in by later save/load calls.
    self._ckpt_logs = self._ckpt_file = None
    self._current_name = self._latest_saved_name = None
    self._erased_ckpt_file = None
    self._num_ckpts = num_ckpts