Example #1
    def save(self):
        """
        Save the trainer.

        `Trainer` parameters like epoch, best_so_far, model, optimizer
        and early_stopping will be saved to `trainer.pt` under
        `self._save_dir`.

        """
        checkpoint = self._save_dir.joinpath('trainer.pt')
        logger.info(f" ** save trainer model to {checkpoint} ** ")
        if self._data_parallel:
            model = self._model.module.state_dict()
        else:
            model = self._model.state_dict()
        state = {
            'epoch': self._epoch,
            'model': model,
            'optimizer': self._optimizer.state_dict(),
            'early_stopping': self._early_stopping.state_dict(),
        }
        if self._epoch_scheduler:
            state['epoch_scheduler'] = self._epoch_scheduler.state_dict()
        if self._step_scheduler:
            state['step_scheduler'] = self._step_scheduler.state_dict()
        torch.save(state, checkpoint)
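The checkpoint written above is an ordinary `torch.save` dictionary, so it can also be reloaded outside the class with plain PyTorch calls. A minimal sketch, assuming a hypothetical `model`, `optimizer` and checkpoint path:

import torch

state = torch.load('save/trainer.pt', map_location='cpu')
model.load_state_dict(state['model'])          # restore the model weights
optimizer.load_state_dict(state['optimizer'])  # restore the optimizer state
start_epoch = state['epoch'] + 1               # resume from the next epoch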
Example #2
def get_dataloader(train_data, valid_data, args):
    logger.info("使用Dataset和DataLoader对数据进行封装")
    train_dataset = PairDataset(train_data, num_neg=0)
    valid_dataset = PairDataset(valid_data, num_neg=0)
    padding = MultiQAPadding(fixed_length_uttr=args.fixed_length_uttr,
                             fixed_length_resp=args.fixed_length_resp,
                             fixed_length_turn=args.fixed_length_turn)

    train_dataloader = DictDataLoader(train_dataset,
                                      batch_size=args.batch_size,
                                      turns=args.fixed_length_turn,
                                      stage='train',
                                      shuffle=True,
                                      sort=False,
                                      callback=padding)
    valid_dataloader = DictDataLoader(valid_dataset,
                                      batch_size=args.batch_size,
                                      turns=args.fixed_length_turn,
                                      stage='dev',
                                      shuffle=False,
                                      sort=False,
                                      callback=padding)

    for i, (x, y) in enumerate(train_dataloader):
        # log the shape of the utterance tensor
        logger.info(f"The shape of utterance is {x[constants.UTTRS].shape}")
        if i == 0:
            break
    return train_dataloader, valid_dataloader
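`PairDataset`, `DictDataLoader` and `MultiQAPadding` above are project-specific classes. With a plain `torch.utils.data.Dataset`, roughly the same wiring would look like the sketch below, where `pad_batch` is a hypothetical stand-in for the padding callback:

from torch.utils.data import DataLoader

# hypothetical plain-PyTorch analogue of the loaders above
train_loader = DataLoader(train_dataset, batch_size=args.batch_size,
                          shuffle=True, collate_fn=pad_batch)
valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size,
                          shuffle=False, collate_fn=pad_batch)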
Example #3
    def evaluate(self, dataloader: DataLoader):
        result = dict()
        y_pred = self.predict(dataloader)
        y_true = dataloader.label
        if isinstance(self._task, tasks.Classification):
            y_pred_label = np.argmax(y_pred[:10], axis=-1)
            # log the first 10 true labels and predicted labels
            logger.info(
                f"The first 10 true labels are {y_true[:10]} | pred labels are {y_pred_label}"
            )
        elif isinstance(self._task, tasks.Ranking):
            y_true = y_true.reshape(len(y_true), 1)
        loss = sum([
            c(torch.tensor(y_pred), torch.tensor(y_true))
            for c in self._criterions
        ])
        self._writer.add_scalar("Loss/eval", loss.item(), self._iteration)
        try:
            id_left = dataloader.id_left
        except AttributeError:
            # id_left is only needed for the ranking metrics below
            id_left = None

        if isinstance(self._task, tasks.Ranking):
            for metric in self._task.metrics:
                result[metric] = self._eval_metric_on_data_frame(
                    metric, id_left, y_true, y_pred.squeeze(axis=-1))
        else:
            for metric in self._task.metrics:
                result[metric] = metric(y_true, y_pred)
        return result, loss
Example #4
    def save_model(self):
        """Save the model."""
        checkpoint = self._save_dir.joinpath('model.pt')
        logger.info(f" ** save raw model to {checkpoint} ** ")
        if self._data_parallel:
            torch.save(self._model.module.state_dict(), checkpoint)
        else:
            torch.save(self._model.state_dict(), checkpoint)
Example #5
    def restore_model(self, checkpoint: typing.Union[str, Path]):
        """
        Restore model.

        :param checkpoint: A checkpoint from which to continue training.

        """
        logger.info(" ** restore raw model ** ")
        state = torch.load(checkpoint, map_location=self._device)
        if self._data_parallel:
            self._model.module.load_state_dict(state)
        else:
            self._model.load_state_dict(state)
Example #6
    def _load_path(self, checkpoint: typing.Union[str, Path],
                   save_dir: typing.Union[str, Path]):
        logger.info(" ** restore exist model checkpoint ** ")
        if not save_dir:
            save_dir = Path('.').joinpath('save')
            if not Path(save_dir).exists():
                Path(save_dir).mkdir(parents=True)

        self._save_dir = Path(save_dir)

        if checkpoint:
            if self._save_all:
                self.restore(checkpoint)
            else:
                self.restore_model(checkpoint)
Example #7
    def __init__(self,
                 model: BaseModel,
                 optimizer: Optimizer,
                 trainloader: DataLoader,
                 validloader: DataLoader,
                 device: typing.Union[torch.device, int, list, None] = None,
                 writer: SummaryWriter = None,
                 start_epoch: int = 1,
                 epochs: int = 10,
                 validate_interval: typing.Optional[int] = None,
                 epoch_scheduler: typing.Any = None,
                 step_scheduler: typing.Any = None,
                 clip_norm: typing.Union[float, int] = None,
                 l1_reg: float = 0.0,
                 l2_reg: float = 0.0,
                 patience: typing.Optional[int] = None,
                 key: typing.Any = None,
                 checkpoint: typing.Union[str, Path] = None,
                 save_dir: typing.Union[str, Path] = None,
                 save_all: bool = False,
                 verbose: int = 1,
                 **kwargs):
        log_dir = Path(save_dir).joinpath("runs")
        self._writer = writer or SummaryWriter(log_dir=log_dir)
        logger.info(f" ** save tensorboard to {self._writer.get_logdir()} ** ")

        self._load_model(model, device)
        self._load_dataloader(trainloader, validloader, validate_interval)
        self._optimizer = optimizer
        self._epoch_scheduler = epoch_scheduler  # stepped once per epoch, so configure its step count to match the number of epochs
        self._step_scheduler = step_scheduler  # stepped once per training step
        self._clip_norm = clip_norm  # gradient clipping threshold
        # regularization coefficients
        self._l1_reg = l1_reg
        self._l2_reg = l2_reg
        self._criterions = self._task.losses

        if not key:
            key = self._task.metrics[0]
        self._early_stopping = EarlyStopping(patience=patience, key=key)

        self._start_epoch = start_epoch
        self._epochs = epochs
        self._iteration = 0
        self._verbose = verbose
        self._save_all = save_all

        self._load_path(checkpoint, save_dir)
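Taken together with the other examples, a hypothetical construction call could look like the following (assuming the class these methods belong to is named `Trainer`; the argument names come from the signature above and the dataloaders from Example #2):

# hypothetical wiring; Trainer, model and the dataloaders come from the surrounding project
trainer = Trainer(model=model,
                  optimizer=torch.optim.Adam(model.parameters(), lr=args.lr),
                  trainloader=train_dataloader,
                  validloader=valid_dataloader,
                  epochs=args.epochs,
                  patience=5,
                  save_dir='./save',
                  device=[0, 1])  # a list of GPU ids switches on DataParallel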
Example #8
    def restore(self, checkpoint: typing.Union[str, Path] = None):
        """
        Restore trainer.

        :param checkpoint: A checkpoint from which to continue training.

        """
        logger.info(" ** restore trainer model ** ")
        state = torch.load(checkpoint, map_location=self._device)
        if self._data_parallel:
            self._model.module.load_state_dict(state['model'])
        else:
            self._model.load_state_dict(state['model'])
        self._optimizer.load_state_dict(state['optimizer'])
        self._start_epoch = state['epoch'] + 1
        self._early_stopping.load_state_dict(state['early_stopping'])
        if self._epoch_scheduler:
            self._epoch_scheduler.load_state_dict(state['epoch_scheduler'])
        if self._step_scheduler:
            self._step_scheduler.load_state_dict(state['step_scheduler'])
Example #9
    def _load_model(self,
                    model: BaseModel,
                    device: typing.Union[torch.device, int, list,
                                         None] = None):
        if not isinstance(model, BaseModel):
            raise ValueError(f"model should be a `BaseModel` instance. "
                             f"But got {type(model)}")

        logger.info(" ** load model and device ** ")
        self._task = model.params['task']
        self._data_parallel = False
        self._model = model

        # If multiple GPUs are specified, default to data parallelism
        if isinstance(device, list) and len(device):
            logger.info(" ** data parallel ** ")
            self._data_parallel = True
            self._model = torch.nn.DataParallel(self._model, device_ids=device)
            self._device = device[0]
        else:
            self._device = parse_device(device)
        self._model.to(self._device)
Example #10
def preprocess_train_and_val(train_file,
                             valid_file,
                             args,
                             remove_placeholder=False):
    logger.info("读取数据")
    train_data = pd.read_csv(train_file)
    valid_data = pd.read_csv(valid_file)
    train_data.dropna(axis=0,
                      subset=[constants.UTTRS, constants.RESP],
                      inplace=True)
    valid_data.dropna(axis=0,
                      subset=[constants.UTTRS, constants.RESP],
                      inplace=True)
    # Remove placeholder tokens from the persona data
    if remove_placeholder:
        logger.info("Removing invalid placeholder markers")
        placeholder = ['[', ']', '_']
        columns = [constants.UTTRS, constants.LAST, constants.RESP]
        for col in columns:
            train_data[col] = train_data[col].apply(
                lambda s: ''.join([c for c in s if c not in placeholder]))
            valid_data[col] = valid_data[col].apply(
                lambda s: ''.join([c for c in s if c not in placeholder]))
    logger.info(
        f"Training set size: {train_data.shape[0]} | validation set size: {valid_data.shape[0]}")
    load_fail = True
    # -----------------------------------

    logger.info("使用Preprocessor处理数据")
    # 是否加载之前训练好的预处理工具
    if args.is_load_preprocessor:
        try:
            preprocessor = load_preprocessor(args.load_preprocessor_path)
            logger.info(f"成功从 {args.load_preprocessor_path} 加载了预处理器")
            load_fail = False
        except Exception:
            # fall back to fitting a new preprocessor below
            load_fail = True
    if load_fail:
        logger.info("基于现有数据训练预处理器")
        preprocessor = CNPreprocessorForMultiQA()
        preprocessor = preprocessor.fit(
            train_data, columns=[constants.LAST, constants.RESP])
        preprocessor.save(args.load_preprocessor_path)

    train_data = preprocessor.transform(train_data,
                                        uttr_col=constants.UTTRS,
                                        resp_col=constants.RESP)
    valid_data = preprocessor.transform(valid_data,
                                        uttr_col=constants.UTTRS,
                                        resp_col=constants.RESP)
    use_cols = [
        'D_num', 'turns', 'utterances', 'response', 'utterances_len',
        'response_len', 'label'
    ]
    train_data = train_data[use_cols]
    valid_data = valid_data[use_cols]

    logger.info(
        f"After preprocessing: training set size {train_data.shape[0]} | validation set size {valid_data.shape[0]}")

    return train_data, valid_data, preprocessor
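A hypothetical end-to-end call that chains this function with `get_dataloader` from Example #2 (file paths and `args` are placeholders):

train_df, valid_df, preprocessor = preprocess_train_and_val(
    'data/train.csv', 'data/valid.csv', args, remove_placeholder=True)
train_dataloader, valid_dataloader = get_dataloader(train_df, valid_df, args)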
Example #11
                        help="Whether to load preprocessor.")
    parser.add_argument('--remove_preprocessor',
                        action='store_true',
                        help="Whether to remove the exist preprocessor.")

    # Miscellaneous configuration arguments
    parser.add_argument('--cuda_num',
                        default='0,1,2,3',
                        type=str,
                        help="The number of the used CUDA.")

    args = parser.parse_args()
    os.environ['CUDA_VISIBLE_DEVICES'] = args.cuda_num

    logger.info(
        f"----------------------- 训练 {args.model_name.upper()}: {time.ctime()} --------------------------"
    )
    seed_everything(args.seed)

    # Log the training arguments
    logger.info(
        f"Training arguments:\n L2_REG: {args.l2_reg} | Epochs: {args.epochs} | Batch Size: {args.batch_size} | LR: {args.lr} | "
        f"Uttr_Len: {args.fixed_length_uttr} | Resp_Len: {args.fixed_length_resp} | Turn: {args.fixed_length_turn}"
    )

    # Whether to delete a previously fitted preprocessor
    ## remove the directory and all files inside it
    if args.remove_preprocessor:
        try:
            shutil.rmtree(args.load_preprocessor_path)
        except Exception as e:
Example #12
    def _run_epoch(self):
        """
        Run each epoch.

        The training steps:
            - Get batch and feed them into model
            - Get outputs, calculate all losses and sum them up
            - Loss backwards and optimizer steps
            - Evaluation
            - Update and output result
        """
        num_batch = len(self._trainloader)
        train_loss = AverageMeter()
        with tqdm(enumerate(self._trainloader),
                  total=num_batch,
                  disable=not self._verbose) as pbar:
            for step, (inputs, target) in pbar:
                # Run Train
                outputs = self._model(inputs)
                ## compute all losses and sum them up
                loss = sum([c(outputs, target) for c in self._criterions])

                self._backward(loss)
                ## update the running training loss
                train_loss.update(loss.item())

                ## configure the progress bar
                pbar.set_description(f"Epoch {self._epoch}/{self._epochs}")
                pbar.set_postfix(loss=f"{loss.item():.3f}")
                self._writer.add_scalar("Loss/train", loss.item(),
                                        self._iteration)

                # Run Evaluate
                self._iteration += 1
                if self._iteration % self._validate_interval == 0:
                    pbar.update(1)
                    if self._verbose:
                        pbar.write(f"[Iter-{self._iteration} "
                                   f"Loss-{train_loss.avg:.3f}]")
                    ## evaluate on the validation loader and record the results
                    result, eval_loss = self.evaluate(self._validloader)
                    m_string = ""
                    for metric in result:
                        res = result[metric]
                        if not isinstance(metric, str):
                            metric = metric.ALIAS[0]
                        self._writer.add_scalar(f"{metric}/eval", res,
                                                self._iteration)
                        m_string += f"| {metric}: {res} "

                    logger.info(
                        f"Epoch: {self._epoch} | Train Loss: {loss.item():.3f} | "
                        f"Eval loss: {eval_loss.item(): .3f} " + m_string)
                    if self._verbose:
                        pbar.write(" Validation: " +
                                   '-'.join(f"{k}: {round(v, 4)}"
                                            for k, v in result.items()))
                    ## Early Stopping
                    self._early_stopping.update(result)
                    if self._early_stopping.should_stop_early:
                        self._save()
                        pbar.write("Ran out of patience. Stop training...")
                        break
                    elif self._early_stopping.is_best_so_far:
                        self._save()
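`_run_epoch` only handles a single epoch. A hypothetical outer loop (for example a `run` method, not shown in these examples) would typically drive it along these lines, using the attributes set in `__init__`:

# hypothetical outer loop around _run_epoch
for epoch in range(self._start_epoch, self._epochs + 1):
    self._epoch = epoch
    self._run_epoch()
    if self._epoch_scheduler:
        self._epoch_scheduler.step()          # per-epoch scheduler update
    if self._early_stopping.should_stop_early:
        break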