# Example #1 -- file: train.py, project: wjunneng/PaddleNLP
def train():
    """Fine-tune ``ErnieForGeneration`` on the Poetry dataset.

    All configuration is read from the module-level ``args`` namespace.

    Steps:
    - Optionally load weights from ``args.init_checkpoint``.
    - Train for ``args.num_epochs`` epochs with AdamW and a linear
      warmup/decay learning-rate schedule.
    - Every ``args.save_steps`` global steps, on rank 0 only: run
      ``evaluate`` on the dev split, then save the model and tokenizer to
      ``<args.output_dir>/model_<global_step>``.
    """
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # Choose the tokenizer family from the model name. "ernie-tiny" must be
    # tested before "ernie" because it also contains that substring.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    train_dataset, dev_dataset = Poetry.get_datasets(['train', 'dev'])
    # Use the dedicated [ATTN] token when the vocab has one, else fall back to [MASK].
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # The last row of the sentence-type embedding table is the target type id.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len,
                                 noise_prob=args.noise_prob,
                                 use_random_noice=args.use_random_noice)

    train_dataset = train_dataset.apply(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_dataset,
                                 batch_sampler=dev_batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=0,
                                 return_list=True)

    # Vocabulary size; used as the one-hot depth when label smoothing is on.
    label_num = model.word_emb.weight.shape[0]
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    max_steps = len(train_data_loader) * args.num_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps,
                                         args.warmup_proportion)

    # Names of the parameters that receive weight decay: every parameter
    # except biases and normalization weights. The set is built once here.
    # Previously a fresh list was built on every apply_decay_param_fun call,
    # i.e. for every parameter on every step.
    decay_params = {
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(1.0),
        apply_decay_param_fun=lambda x: x in decay_params)

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    global_step = 1
    tic_train = time.time()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids,
             attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
             mask_attn_2_srctgtattn, tgt_labels, _) = batch
            # Pass 1: encode the source and keep its key/value caches.
            _, __, info = model(src_ids,
                                sent_ids=src_sids,
                                pos_ids=src_pids,
                                attn_bias=mask_src_2_src,
                                encode_only=True)
            cached_k, cached_v = info['caches']
            # Pass 2: encode the target on top of the source caches.
            _, __, info = model(tgt_ids,
                                sent_ids=tgt_sids,
                                pos_ids=tgt_pids,
                                attn_bias=mask_tgt_2_srctgt,
                                past_cache=(cached_k, cached_v),
                                encode_only=True)
            cached_k2, cached_v2 = info['caches']
            # Join the source and target caches per layer along axis 1
            # (presumably the sequence axis -- confirm against the model).
            past_cache_k = [
                paddle.concat([k, k2], 1)
                for k, k2 in zip(cached_k, cached_k2)
            ]
            past_cache_v = [
                paddle.concat([v, v2], 1)
                for v, v2 in zip(cached_v, cached_v2)
            ]
            if args.label_smooth > 0.:
                tgt_labels = nn.functional.label_smooth(
                    nn.functional.one_hot(tgt_labels, label_num),
                    epsilon=args.label_smooth)
            # Pass 3: predict at the [ATTN] positions and compute the loss.
            loss, _, __ = model(attn_ids,
                                sent_ids=tgt_sids,
                                pos_ids=tgt_pids,
                                attn_bias=mask_attn_2_srctgtattn,
                                past_cache=(past_cache_k, past_cache_v),
                                tgt_labels=tgt_labels,
                                tgt_pos=paddle.nonzero(attn_ids == attn_id))
            if global_step % args.logging_steps == 0:
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train), lr_scheduler.get_lr()))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_gradients()
            if global_step % args.save_steps == 0 and (
                (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0):
                evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
                         attn_id, tgt_type_id, args)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                os.makedirs(output_dir, exist_ok=True)
                # Unwrap DataParallel so the bare model's weights are saved.
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
            global_step += 1
# Example #2 -- file: eval.py, project: wbj0110/models
def _truncate_at_eos(ids, eos_id):
    """Return the prefix of ``ids`` before the first ``eos_id``.

    If ``eos_id`` does not occur in ``ids``, the list is returned unchanged.
    A sequence cut off before its end token is therefore kept rather than
    raising ``ValueError`` from ``list.index``.
    """
    if eos_id in ids:
        return ids[:ids.index(eos_id)]
    return ids


def evaluate():
    """Decode the Poetry dev split with beam search and log ROUGE-1/ROUGE-2.

    All configuration is read from the module-level ``args`` namespace. The
    model is optionally initialised from ``args.init_checkpoint``. Target
    labels are never fed to the model; they are only used as references
    for the ROUGE scores.
    """
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # Choose the tokenizer family from the model name. "ernie-tiny" must be
    # tested before "ernie" because it also contains that substring.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    # NOTE(review): assumes get_datasets(['dev']) returns a single dataset
    # (it is used directly below); confirm against the Poetry dataset API.
    dev_dataset = Poetry.get_datasets(['dev'])
    # Use the dedicated [ATTN] token when the vocab has one, else fall back to [MASK].
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # The last row of the sentence-type embedding table is the target type id.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len)

    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    data_loader = DataLoader(dataset=dev_dataset,
                             batch_sampler=dev_batch_sampler,
                             collate_fn=batchify_fn,
                             num_workers=0,
                             return_list=True)

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    model.eval()
    vocab = tokenizer.vocab
    eos_id = vocab[tokenizer.sep_token]
    sos_id = vocab[tokenizer.cls_token]
    pad_id = vocab[tokenizer.pad_token]
    unk_id = vocab[tokenizer.unk_token]
    vocab_size = len(vocab)
    evaluated_sentences_ids = []
    reference_sentences_ids = []
    logger.info("Evaluating...")
    for data in tqdm(data_loader):
        # Only the source side and the raw target labels are used; the
        # target inputs are never fed to the model at inference time.
        (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
         raw_tgt_labels) = data
        # Use greedy_search_infilling or beam_search_infilling to get predictions
        output_ids = beam_search_infilling(model,
                                           src_ids,
                                           src_sids,
                                           eos_id=eos_id,
                                           sos_id=sos_id,
                                           attn_id=attn_id,
                                           pad_id=pad_id,
                                           unk_id=unk_id,
                                           vocab_size=vocab_size,
                                           max_decode_len=args.max_decode_len,
                                           max_encode_len=args.max_encode_len,
                                           beam_width=args.beam_width,
                                           length_penalty=args.length_penalty,
                                           tgt_type_id=tgt_type_id)

        # Cut predictions and references at the first end token. References
        # use the same guarded helper: a row without eos_id no longer crashes.
        evaluated_sentences_ids.extend(
            _truncate_at_eos(ids, eos_id) for ids in output_ids.tolist())
        reference_sentences_ids.extend(
            _truncate_at_eos(ids, eos_id)
            for ids in raw_tgt_labels.numpy().tolist())

    score1 = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)
    score2 = rouge2.score(evaluated_sentences_ids, reference_sentences_ids)

    logger.info("Rouge-1: %.5f ,Rouge-2: %.5f" % (score1 * 100, score2 * 100))
# Example #3
def train():
    """Fine-tune ``ErnieForGeneration`` on the ``poetry`` dataset.

    All configuration is read from the module-level ``args`` namespace.

    Steps:
    - Optionally load weights from ``args.init_checkpoint``.
    - Wrap the model in ``StackModel`` so that its forward returns only
      the loss. Under multi-card training it is further wrapped in
      ``paddle.DataParallel``.
    - Train for ``args.num_epochs`` epochs with AdamW and a linear
      warmup/decay learning-rate schedule.
    - Every ``args.save_steps`` global steps, on rank 0: run ``evaluate``
      on the dev split, then save the model and tokenizer to
      ``<args.output_dir>/model_<global_step>``.
    """
    paddle.set_device(args.device)
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # Choose the tokenizer family from the model name. "ernie-tiny" must be
    # tested before "ernie" because it also contains that substring.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    train_dataset, dev_dataset = load_dataset(
        'poetry', splits=('train', 'dev'), lazy=False)
    # Use the dedicated [ATTN] token when the vocab has one, else fall back to [MASK].
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # The last row of the sentence-type embedding table is the target type id.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(
        tokenizer=tokenizer,
        attn_id=attn_id,
        tgt_type_id=tgt_type_id,
        max_encode_len=args.max_encode_len,
        max_decode_len=args.max_decode_len,
        noise_prob=args.noise_prob,
        use_random_noice=args.use_random_noice)

    train_dataset = train_dataset.map(trans_func)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # src_tids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_type_id),  # tgt_tids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))
    train_data_loader = DataLoader(
        dataset=train_dataset,
        batch_sampler=train_batch_sampler,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    dev_dataset = dev_dataset.map(trans_func)
    dev_data_loader = DataLoader(
        dataset=dev_dataset,
        batch_size=args.batch_size,
        collate_fn=batchify_fn,
        num_workers=0,
        return_list=True)

    # Vocabulary size; used as the one-hot depth when label smoothing is on.
    label_num = model.word_emb.weight.shape[0]
    train_model = StackModel(model)
    if paddle.distributed.get_world_size() > 1:
        # All 'forward' outputs derived from the module parameters using in DataParallel
        # must participate in the calculation of losses and subsequent gradient calculations.
        # So we use StackModel here to make the model only output loss in its 'forward' function.
        train_model = paddle.DataParallel(train_model)

    max_steps = len(train_data_loader) * args.num_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps,
                                         args.warmup_proportion)

    # Parameter names that receive weight decay; all bias and LayerNorm
    # parameters are excluded. A set gives O(1) lookups in the callback.
    decay_params = {
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(1.0),
        apply_decay_param_fun=lambda x: x in decay_params)

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    global_step = 1
    tic_train = time.time()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids,
             mask_src_2_src, mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
             tgt_labels, _) = batch
            if args.label_smooth > 0.:
                tgt_labels = nn.functional.label_smooth(
                    nn.functional.one_hot(tgt_labels, label_num),
                    epsilon=args.label_smooth)
            # Positions of the [ATTN] placeholders, where predictions are scored.
            tgt_pos = paddle.nonzero(attn_ids == attn_id)
            loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids,
                               tgt_pids, attn_ids, mask_src_2_src,
                               mask_tgt_2_srctgt, mask_attn_2_srctgtattn,
                               tgt_labels, tgt_pos)
            if global_step % args.logging_steps == 0:
                if paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train), lr_scheduler.get_lr()))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_grad()
            if global_step % args.save_steps == 0 and paddle.distributed.get_rank(
            ) == 0:
                evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
                         attn_id, tgt_type_id, args)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                os.makedirs(output_dir, exist_ok=True)
                # Only `train_model` is ever wrapped in DataParallel; `model`
                # is always the bare ErnieForGeneration, so save it directly.
                model.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
            global_step += 1
# Example #4 -- file: module.py, project: houj04/PaddleHub
    def finetune(
        self,
        train_path,
        dev_path=None,
        save_dir="ernie_gen_result",
        init_ckpt_path=None,
        use_gpu=True,
        max_steps=500,
        batch_size=8,
        max_encode_len=50,
        max_decode_len=50,
        learning_rate=5e-5,
        warmup_proportion=0.1,
        weight_decay=0.1,
        noise_prob=0,
        label_smooth=0,
        beam_width=5,
        length_penalty=1.0,
        log_interval=100,
        save_interval=200,
    ):
        """
        Finetune with the specified dataset.

        Args:
            train_path(str): the train dataset path.
            dev_path(str): the dev dataset path.
            save_dir(str): the model params and dev dataset predict result save path.
            init_ckpt_path(str): incremental training load path.
            use_gpu(bool): use gpu or not.
            max_steps(int): max training steps.
            batch_size(int): the batch size.
            max_encode_len(int): the max encode length.
            max_decode_len(int): the max decode length.
            learning_rate(float): the learning rate.
            warmup_proportion(float): the warmup proportion.
            weight_decay(float): the weight decay magnitude.
            noise_prob(float): the noise probability. see the ernie gen paper for details.
            label_smooth(float): the label smooth magnitude.
            beam_width(int): the beam size during evaluating the dev dataset.
            length_penalty(float): the length penalty during evaluating the dev dataset.
            log_interval(int): the log interval.
            save_interval(int): the save interval. dev set will be evaluated after saving.

        Return:
            result(dict): A Dictionary of shape::
                {
                    last_save_path(str or None): last saved ``.pdparams`` path,
                        or None when ``save_dir`` is empty and nothing was saved.
                    last_ppl(float): perplexity (exp of the loss) of the last training step.
                }

        Raises:
            ValueError: if the training dataset yields no batches.
        """
        paddle.disable_static()
        paddle.set_device('gpu' if use_gpu else 'cpu')

        if init_ckpt_path is not None:
            logger.info('loading checkpoint from %s' % init_ckpt_path)
            sd = paddle.load(init_ckpt_path)
            self.model.set_state_dict(sd)

        train_dataset = self._load_dataset(train_path)
        attn_id = self.tokenizer.vocab['[MASK]']
        trans_func = convert_example(tokenizer=self.tokenizer,
                                     attn_id=attn_id,
                                     tgt_type_id=1,
                                     max_encode_len=max_encode_len,
                                     max_decode_len=max_decode_len,
                                     noise_prob=noise_prob)

        train_dataset = train_dataset.map(trans_func)
        train_batch_sampler = paddle.io.BatchSampler(train_dataset, batch_size=batch_size, shuffle=True)
        batchify_fn = lambda samples, fn=Tuple(
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # src_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # src_pids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # src_tids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_pids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_type_id),  # tgt_tids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # attn_ids
            Pad(axis=0, pad_val=self.tokenizer.pad_token_id),  # tgt_labels
        ): after_padding(fn(samples))
        train_data_loader = DataLoader(dataset=train_dataset,
                                       batch_sampler=train_batch_sampler,
                                       collate_fn=batchify_fn,
                                       num_workers=0,
                                       return_list=True)

        dev_data_loader = None
        if dev_path:
            dev_dataset = self._load_dataset(dev_path)
            dev_dataset = dev_dataset.map(trans_func)
            dev_data_loader = DataLoader(dataset=dev_dataset,
                                         batch_size=batch_size,
                                         collate_fn=batchify_fn,
                                         num_workers=0,
                                         return_list=True)

        # Vocabulary size; used as the one-hot depth when label smoothing is on.
        label_num = self.model.word_emb.weight.shape[0]
        train_model = StackModel(self.model)
        lr_scheduler = LinearDecayWithWarmup(learning_rate, max_steps, warmup_proportion)
        # Generate parameter names needed to perform weight decay.
        # All bias and LayerNorm parameters are excluded.
        decay_params = {p.name for n, p in self.model.named_parameters() if not any(nd in n for nd in ["bias", "norm"])}
        optimizer = paddle.optimizer.AdamW(learning_rate=lr_scheduler,
                                           parameters=self.model.parameters(),
                                           weight_decay=weight_decay,
                                           grad_clip=nn.ClipGradByGlobalNorm(1.0),
                                           apply_decay_param_fun=lambda x: x in decay_params)

        rouge1 = Rouge1()
        rouge2 = Rouge2()

        def _loss_and_ppl(loss_tensor):
            # Flatten first so this works whether loss.numpy() is a 0-d or a
            # 1-element array. exp is taken on the original array dtype.
            loss_arr = np.asarray(loss_tensor.numpy()).reshape(-1)
            return float(loss_arr[0]), float(np.exp(loss_arr)[0])

        def _save_and_evaluate(step, ppl):
            # Save the current weights (and evaluate on dev when available);
            # return the path that was written.
            path = os.path.join(save_dir, "step_%s_ppl_%.5f.pdparams" % (step, ppl))
            logger.info("save the model in %s" % path)
            paddle.save(self.model.state_dict(), path)
            if dev_path:
                self._evaluate(self.model, dev_data_loader, self.tokenizer, rouge1, rouge2, attn_id,
                               max_decode_len, max_encode_len, beam_width, length_penalty)
            return path

        save_path = None  # stays None when save_dir is empty
        loss = None
        global_step = 1
        if save_dir:
            os.makedirs(save_dir, exist_ok=True)
        # Cycle over the training data until max_steps optimizer steps are done.
        while True:
            seen_batch = False
            for batch in train_data_loader:
                seen_batch = True
                (src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
                 mask_attn_2_srctgtattn, tgt_labels, _) = batch
                if label_smooth > 0.:
                    tgt_labels = nn.functional.label_smooth(nn.functional.one_hot(tgt_labels, label_num),
                                                            epsilon=label_smooth)

                # Positions of the [MASK]/attn placeholders, where predictions are scored.
                tgt_pos = paddle.nonzero(attn_ids == attn_id)
                loss = train_model(src_ids, src_tids, src_pids, tgt_ids, tgt_tids, tgt_pids, attn_ids, mask_src_2_src,
                                   mask_tgt_2_srctgt, mask_attn_2_srctgtattn, tgt_labels, tgt_pos)

                loss.backward()
                optimizer.step()
                lr_scheduler.step()
                optimizer.clear_grad()

                if global_step % log_interval == 0 and paddle.distributed.get_rank() == 0:
                    loss_np, ppl = _loss_and_ppl(loss)
                    logger.info('[step %d / %d]train loss %.5f, ppl %.5f, elr %.3e' %
                                (global_step, max_steps, loss_np, ppl, lr_scheduler.get_lr()))
                if save_dir and global_step % save_interval == 0:
                    _, ppl = _loss_and_ppl(loss)
                    save_path = _save_and_evaluate(global_step, ppl)

                if global_step >= max_steps:
                    break
                global_step += 1

            # Without this check an empty dataset would spin forever because
            # global_step never advances.
            if not seen_batch:
                raise ValueError("the training dataset at %s yielded no batches" % train_path)
            if global_step >= max_steps:
                break

        loss_np, ppl = _loss_and_ppl(loss)
        # Checkpoint the final step unless the loop already saved it.
        if global_step % save_interval != 0:
            logger.info('[final step %d]train loss %.5f, ppl %.5f, elr %.3e' %
                        (global_step, loss_np, ppl, lr_scheduler.get_lr()))
            if save_dir:
                save_path = _save_and_evaluate(global_step, ppl)

        result = {
            "last_save_path": save_path,
            "last_ppl": ppl,
        }

        return result