示例#1
0
    def _save_model(self, save_path: str, model: PreTrainedModel, tokenizer: PreTrainedTokenizer,
                    iteration: int, optimizer: Optimizer = None, save_as_best: bool = False,
                    extra: dict = None, include_iteration: bool = True, name: str = 'model'):
        """Save model weights, tokenizer and extra training state into one directory.

        The target directory is ``<save_path>/<name>_best`` when *save_as_best*
        is set; otherwise ``<save_path>/<name>_<iteration>``, or
        ``<save_path>/<name>`` when *include_iteration* is false. It is created
        via ``util.create_directories_dir`` (presumably a no-op when it already
        exists -- confirm in util).

        Args:
            save_path: Root directory under which the checkpoint directory is made.
            model: Model to save with ``save_pretrained``. If it is wrapped in
                ``DataParallel`` or ``DistributedDataParallel``, the wrapped
                ``.module`` is saved instead.
            tokenizer: Saved with ``save_pretrained`` into the same directory.
            iteration: Stored in the extra state and used in the directory name.
            optimizer: If given, its ``state_dict()`` is stored under ``'optimizer'``.
            save_as_best: Write to ``<name>_best`` (ignores *include_iteration*).
            extra: Additional entries merged into the saved extra state.
            include_iteration: Suffix the directory name with *iteration*.
            name: Prefix of the checkpoint directory name.

        The extra state dict is written with ``torch.save`` to ``extra.state``
        inside the checkpoint directory.
        """
        extra_state = dict(iteration=iteration)

        if optimizer is not None:
            extra_state['optimizer'] = optimizer.state_dict()

        if extra:
            # merged last, so caller-supplied keys override 'iteration'/'optimizer'
            extra_state.update(extra)

        if save_as_best:
            dir_name = '%s_best' % name
        else:
            dir_name = '%s_%s' % (name, iteration) if include_iteration else name
        dir_path = os.path.join(save_path, dir_name)

        util.create_directories_dir(dir_path)

        # save model: save_pretrained lives on the wrapped model, not on the
        # parallel wrapper, so unwrap both single- and multi-process wrappers
        parallel_wrappers = (DataParallel, torch.nn.parallel.DistributedDataParallel)
        model_to_save = model.module if isinstance(model, parallel_wrappers) else model
        model_to_save.save_pretrained(dir_path)

        # save vocabulary
        tokenizer.save_pretrained(dir_path)

        # save extra
        state_path = os.path.join(dir_path, 'extra.state')
        torch.save(extra_state, state_path)
示例#2
0
    def __init__(self, args: argparse.Namespace):
        """Set up output directories, logging, compute device and RNG seed.

        Reads ``args.debug``, ``args.cpu`` and ``args.seed`` unconditionally.
        If ``args`` has ``save_path`` and/or ``log_path``, the corresponding
        directory ``<root>/<label>/<rel_filter_threshold>/<epochs>`` is built
        from ``args.label``, ``args.rel_filter_threshold`` and ``args.epochs``
        and created. When ``log_path`` is present, this also configures the
        logger and the tensorboard writer, and logs the arguments.

        Args:
            args: Parsed command-line arguments for this run.
        """
        self._args = args
        self._debug = self._args.debug

        def run_dir(root):
            # Checkpoints and logs share the same per-run directory layout.
            # Called lazily so the label/threshold/epochs attributes are only
            # read when one of the output roots is actually configured.
            return os.path.join(root, self._args.label,
                                str(self._args.rel_filter_threshold),
                                str(self._args.epochs))

        if hasattr(args, 'save_path'):
            self._save_path = run_dir(self._args.save_path)
            util.create_directories_dir(self._save_path)

        # logging
        if hasattr(args, 'log_path'):
            self._log_path = run_dir(self._args.log_path)
            util.create_directories_dir(self._log_path)

            self._log_paths = dict()

            # file + console logging, sharing one format
            log_formatter = logging.Formatter(
                "%(asctime)s [%(threadName)-12.12s] [%(levelname)-5.5s]  %(message)s"
            )
            # logging.getLogger() with no name is the root logger, so records
            # propagated from other modules' loggers reach these handlers too
            self._logger = logging.getLogger()
            # presumably detaches previously attached handlers -- confirm in util
            util.reset_logger(self._logger)

            file_handler = logging.FileHandler(
                os.path.join(self._log_path, 'all.log'))
            file_handler.setFormatter(log_formatter)
            self._logger.addHandler(file_handler)

            console_handler = logging.StreamHandler(sys.stdout)
            console_handler.setFormatter(log_formatter)
            self._logger.addHandler(console_handler)

            self._logger.setLevel(logging.DEBUG if self._debug else logging.INFO)

            # tensorboard summary; tensorboardX is presumably an optional
            # import bound to None when the package is unavailable
            self._summary_writer = tensorboardX.SummaryWriter(
                self._log_path) if tensorboardX is not None else None

            self._log_arguments()

        self._best_results = dict()

        # CUDA devices: use the GPU only if available and not disabled via args.cpu
        self._device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.cpu else "cpu")
        self._gpu_count = torch.cuda.device_count()

        # set seed
        if args.seed is not None:
            util.set_seed(args.seed)