def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
                iteration: int, optimizer: Optimizer = None, save_as_best: bool = False,
                extra: dict = None, include_iteration: bool = True, name: str = 'model'):
    """Save a model checkpoint (weights, tokenizer files and extra training state).

    The checkpoint is written to a sub-directory of ``save_path``:
    ``<name>_best`` when ``save_as_best`` is set, otherwise ``<name>_<iteration>``
    (or just ``<name>`` when ``include_iteration`` is False).

    Args:
        save_path: Base directory under which the checkpoint directory is created.
        model: Model to save via ``save_pretrained`` (DataParallel wrappers are unwrapped).
        tokenizer: Tokenizer whose vocabulary files are saved alongside the model.
        iteration: Current training iteration; stored in the extra state and,
            optionally, embedded in the directory name.
        optimizer: If given, its ``state_dict`` is stored in the extra state.
        save_as_best: Save into the fixed ``<name>_best`` directory instead of an
            iteration-specific one.
        extra: Optional additional key/value pairs merged into the extra state.
        include_iteration: Whether to append the iteration to the directory name
            (ignored when ``save_as_best`` is set).  NOTE: annotation fixed from
            ``int`` to ``bool`` — the parameter is a flag.
        name: Base name of the checkpoint directory.
    """
    # Extra state persisted next to the model: iteration counter plus
    # optional optimizer state and caller-supplied entries.
    extra_state = dict(iteration=iteration)

    if optimizer is not None:  # explicit None check instead of truthiness
        extra_state['optimizer'] = optimizer.state_dict()

    if extra:
        extra_state.update(extra)

    if save_as_best:
        dir_path = os.path.join(save_path, f'{name}_best')
    else:
        dir_name = f'{name}_{iteration}' if include_iteration else name
        dir_path = os.path.join(save_path, dir_name)

    util.create_directories_dir(dir_path)

    # save model (unwrap DataParallel so weights are saved without the wrapper)
    if isinstance(model, DataParallel):
        model.module.save_pretrained(dir_path)
    else:
        model.save_pretrained(dir_path)

    # save vocabulary
    tokenizer.save_pretrained(dir_path)

    # save extra training state
    state_path = os.path.join(dir_path, 'extra.state')
    torch.save(extra_state, state_path)
def __init__(self, args: argparse.Namespace):
    """Initialize the trainer: output/log directories, logging, tensorboard,
    device selection and random seed.

    Args:
        args: Parsed command-line arguments.  ``save_path`` and ``log_path``
            are optional (guarded with ``hasattr``); ``debug``, ``cpu`` and
            ``seed`` are read unconditionally.
    """
    self._args = args
    self._debug = self._args.debug

    # NOTE(review): removed a leftover debug print of the full args namespace
    # and an unused ``run_key`` timestamp local.

    if hasattr(args, 'save_path'):
        self._save_path = os.path.join(
            self._args.save_path, self._args.label,
            str(self._args.rel_filter_threshold), str(self._args.epochs))
        util.create_directories_dir(self._save_path)

    # logging — everything below depends on self._log_path, so it is only set
    # up when a log path was supplied (nesting reconstructed from data flow;
    # confirm against the original formatting).
    if hasattr(args, 'log_path'):
        self._log_path = os.path.join(
            self._args.log_path, self._args.label,
            str(self._args.rel_filter_threshold), str(self._args.epochs))
        util.create_directories_dir(self._log_path)

        self._log_paths = dict()

        # file + console logging
        log_formatter = logging.Formatter(
            "%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s] %(message)s"
        )
        self._logger = logging.getLogger()
        util.reset_logger(self._logger)

        file_handler = logging.FileHandler(
            os.path.join(self._log_path, 'all.log'))
        file_handler.setFormatter(log_formatter)
        self._logger.addHandler(file_handler)

        console_handler = logging.StreamHandler(sys.stdout)
        console_handler.setFormatter(log_formatter)
        self._logger.addHandler(console_handler)

        if self._debug:
            self._logger.setLevel(logging.DEBUG)
        else:
            self._logger.setLevel(logging.INFO)

        # tensorboard summary (tensorboardX is an optional dependency)
        self._summary_writer = tensorboardX.SummaryWriter(
            self._log_path) if tensorboardX is not None else None

        self._log_arguments()

    self._best_results = dict()

    # CUDA devices
    self._device = torch.device(
        "cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
    self._gpu_count = torch.cuda.device_count()

    # set seed
    if args.seed is not None:
        util.set_seed(args.seed)