Exemplo n.º 1
0
def predict():
    """Decode the Poetry ``dev`` split with beam search and print the results.

    Configuration is read from the module-level ``args`` namespace
    (``use_gpu``, ``model_name_or_path``, ``init_checkpoint``, ``batch_size``,
    ``max_encode_len``, ``max_decode_len``, ``beam_width``,
    ``length_penalty``). For every dev example, prints the source text, the
    reference target text and the predicted text.

    Raises:
        ValueError: if a source or reference id row contains no EOS id
            (``list.index`` is used to locate it).
    """
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # The tokenizer family is chosen by substring of the model name; the
    # "ernie-tiny" check must come before the more general "ernie" check.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    dev_dataset = Poetry.get_datasets(['dev'])
    # Use a dedicated [ATTN] token when the vocab has one, otherwise [MASK].
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # Target segments use the last row of the sentence-type embedding table.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len)

    # Pad each of the 8 fields produced by trans_func with the pad id, then
    # post-process with after_padding (defined elsewhere), whose 12-field
    # output is unpacked in the loop below.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    test_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                                batch_size=args.batch_size,
                                                shuffle=False)
    data_loader = DataLoader(dataset=dev_dataset,
                             batch_sampler=test_batch_sampler,
                             collate_fn=batchify_fn,
                             num_workers=0,
                             return_list=True)

    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    model.eval()
    vocab = tokenizer.vocab
    eos_id = vocab[tokenizer.sep_token]
    sos_id = vocab[tokenizer.cls_token]
    pad_id = vocab[tokenizer.pad_token]
    unk_id = vocab[tokenizer.unk_token]
    vocab_size = len(vocab)

    def _ids_to_text(ids):
        # Map ids back to tokens, post-process each token and concatenate.
        return ''.join(map(post_process, vocab.to_tokens(ids)))

    logger.info("Predicting...")
    # Pure inference: make sure no autograd graph is recorded while decoding,
    # regardless of whether the decoding helper disables it itself.
    with paddle.no_grad():
        for data in data_loader:
            (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
             raw_tgt_labels) = data  # never use target when infer
            # Use greedy_search_infilling or beam_search_infilling to get predictions
            output_ids = beam_search_infilling(model,
                                               src_ids,
                                               src_sids,
                                               eos_id=eos_id,
                                               sos_id=sos_id,
                                               attn_id=attn_id,
                                               pad_id=pad_id,
                                               unk_id=unk_id,
                                               vocab_size=vocab_size,
                                               max_decode_len=args.max_decode_len,
                                               max_encode_len=args.max_encode_len,
                                               beam_width=args.beam_width,
                                               length_penalty=args.length_penalty,
                                               tgt_type_id=tgt_type_id)

            for source_ids, target_ids, predict_ids in zip(
                    src_ids.numpy().tolist(),
                    raw_tgt_labels.numpy().tolist(), output_ids.tolist()):
                # Truncate the prediction at its first EOS, if it has one.
                if eos_id in predict_ids:
                    predict_ids = predict_ids[:predict_ids.index(eos_id)]
                # For source and reference rows, skip the first id and cut at
                # the first EOS (everything after it is padding/terminator).
                source_sentence = _ids_to_text(
                    source_ids[1:source_ids.index(eos_id)])
                tgt_sentence = _ids_to_text(
                    target_ids[1:target_ids.index(eos_id)])
                predict_sentence = _ids_to_text(predict_ids)
                print("source :%s\ntarget :%s\npredict:%s\n" %
                      (source_sentence, tgt_sentence, predict_sentence))
Exemplo n.º 2
0
def train():
    """Fine-tune ``ErnieForGeneration`` on the Poetry ``train`` split.

    Configuration is read from the module-level ``args`` namespace. Every
    ``args.logging_steps`` global steps, the loss, throughput and learning
    rate are logged. Every ``args.save_steps`` global steps, on rank 0 (or
    whenever ``args.n_gpu <= 1``), the model is evaluated on the ``dev``
    split. The model and tokenizer are then saved under
    ``args.output_dir/model_<global_step>``.
    """
    paddle.set_device("gpu" if args.n_gpu else "cpu")
    if paddle.distributed.get_world_size() > 1:
        paddle.distributed.init_parallel_env()

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # The tokenizer family is chosen by substring of the model name; the
    # "ernie-tiny" check must come before the more general "ernie" check.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)
    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    train_dataset, dev_dataset = Poetry.get_datasets(['train', 'dev'])
    # Use a dedicated [ATTN] token when the vocab has one, otherwise [MASK].
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # Target segments use the last row of the sentence-type embedding table.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len,
                                 noise_prob=args.noise_prob,
                                 use_random_noice=args.use_random_noice)

    train_dataset = train_dataset.apply(trans_func, lazy=True)
    train_batch_sampler = paddle.io.DistributedBatchSampler(
        train_dataset, batch_size=args.batch_size, shuffle=True)
    # Pad each of the 8 fields produced by trans_func with the pad id, then
    # post-process with after_padding (defined elsewhere), whose 12-field
    # output is unpacked in the training loop below.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))
    train_data_loader = DataLoader(dataset=train_dataset,
                                   batch_sampler=train_batch_sampler,
                                   collate_fn=batchify_fn,
                                   num_workers=0,
                                   return_list=True)

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    dev_data_loader = DataLoader(dataset=dev_dataset,
                                 batch_sampler=dev_batch_sampler,
                                 collate_fn=batchify_fn,
                                 num_workers=0,
                                 return_list=True)

    # Output-vocabulary size, read before a possible DataParallel wrap
    # (the wrapper does not expose word_emb directly).
    label_num = model.word_emb.weight.shape[0]
    if paddle.distributed.get_world_size() > 1:
        model = paddle.DataParallel(model)

    max_steps = len(train_data_loader) * args.num_epochs

    lr_scheduler = LinearDecayWithWarmup(args.learning_rate, max_steps,
                                         args.warmup_proportion)

    # Names of parameters that receive weight decay: everything except
    # biases and normalization parameters. AdamW calls apply_decay_param_fun
    # with a parameter name; building the set once makes each call an O(1)
    # lookup instead of rebuilding a list over all parameters per call.
    decay_params = {
        p.name for n, p in model.named_parameters()
        if not any(nd in n for nd in ["bias", "norm"])
    }
    optimizer = paddle.optimizer.AdamW(
        learning_rate=lr_scheduler,
        epsilon=args.adam_epsilon,
        parameters=model.parameters(),
        weight_decay=args.weight_decay,
        grad_clip=nn.ClipGradByGlobalNorm(1.0),
        apply_decay_param_fun=lambda x: x in decay_params)

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    global_step = 1
    tic_train = time.time()
    for epoch in range(args.num_epochs):
        for step, batch in enumerate(train_data_loader, start=1):
            (src_ids, src_sids, src_pids, tgt_ids, tgt_sids, tgt_pids,
             attn_ids, mask_src_2_src, mask_tgt_2_srctgt,
             mask_attn_2_srctgtattn, tgt_labels, _) = batch
            # Pass 1: encode the source and keep its key/value caches.
            _, __, info = model(src_ids,
                                sent_ids=src_sids,
                                pos_ids=src_pids,
                                attn_bias=mask_src_2_src,
                                encode_only=True)
            cached_k, cached_v = info['caches']
            # Pass 2: encode the target, passing the source caches as
            # past_cache.
            _, __, info = model(tgt_ids,
                                sent_ids=tgt_sids,
                                pos_ids=tgt_pids,
                                attn_bias=mask_tgt_2_srctgt,
                                past_cache=(cached_k, cached_v),
                                encode_only=True)
            cached_k2, cached_v2 = info['caches']
            # Concatenate source and target caches along axis 1 for the
            # final [ATTN] pass.
            past_cache_k = [
                paddle.concat([k, k2], 1)
                for k, k2 in zip(cached_k, cached_k2)
            ]
            past_cache_v = [
                paddle.concat([v, v2], 1)
                for v, v2 in zip(cached_v, cached_v2)
            ]
            if args.label_smooth > 0.:
                # Replace hard labels with smoothed one-hot distributions.
                tgt_labels = nn.functional.label_smooth(
                    nn.functional.one_hot(tgt_labels, label_num),
                    epsilon=args.label_smooth)
            # Pass 3: run the attention tokens against both caches; tgt_pos
            # selects the positions where attn_ids equals attn_id.
            loss, _, __ = model(attn_ids,
                                sent_ids=tgt_sids,
                                pos_ids=tgt_pids,
                                attn_bias=mask_attn_2_srctgtattn,
                                past_cache=(past_cache_k, past_cache_v),
                                tgt_labels=tgt_labels,
                                tgt_pos=paddle.nonzero(attn_ids == attn_id))
            if global_step % args.logging_steps == 0:
                if (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0:
                    logger.info(
                        "global step %d, epoch: %d, batch: %d, loss: %f, speed: %.2f step/s, lr: %.3e"
                        % (global_step, epoch, step, loss, args.logging_steps /
                           (time.time() - tic_train), lr_scheduler.get_lr()))
                tic_train = time.time()

            loss.backward()
            optimizer.step()
            lr_scheduler.step()
            optimizer.clear_gradients()
            if global_step % args.save_steps == 0 and (
                (not args.n_gpu > 1) or paddle.distributed.get_rank() == 0):
                evaluate(model, dev_data_loader, tokenizer, rouge1, rouge2,
                         attn_id, tgt_type_id, args)
                output_dir = os.path.join(args.output_dir,
                                          "model_%d" % global_step)
                # exist_ok avoids the check-then-create race of the
                # exists()/makedirs() pair.
                os.makedirs(output_dir, exist_ok=True)
                # Save the underlying layer, not the DataParallel wrapper.
                model_to_save = model._layers if isinstance(
                    model, paddle.DataParallel) else model
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
            global_step += 1
Exemplo n.º 3
0
def evaluate():
    """Score beam-search predictions on the Poetry ``dev`` split with ROUGE.

    Configuration is read from the module-level ``args`` namespace
    (``use_gpu``, ``model_name_or_path``, ``init_checkpoint``, ``batch_size``,
    ``max_encode_len``, ``max_decode_len``, ``beam_width``,
    ``length_penalty``). Logs ROUGE-1 and ROUGE-2, multiplied by 100, over
    the whole dev split.

    Raises:
        ValueError: if a reference id row contains no EOS id (``list.index``
            is used to locate it).
    """
    paddle.set_device("gpu" if args.use_gpu else "cpu")

    model = ErnieForGeneration.from_pretrained(args.model_name_or_path)
    # The tokenizer family is chosen by substring of the model name; the
    # "ernie-tiny" check must come before the more general "ernie" check.
    if "ernie-tiny" in args.model_name_or_path:
        tokenizer = ErnieTinyTokenizer.from_pretrained(args.model_name_or_path)
    elif "ernie" in args.model_name_or_path:
        tokenizer = ErnieTokenizer.from_pretrained(args.model_name_or_path)
    elif "roberta" in args.model_name_or_path or "rbt" in args.model_name_or_path:
        tokenizer = RobertaTokenizer.from_pretrained(args.model_name_or_path)
    elif "electra" in args.model_name_or_path:
        tokenizer = ElectraTokenizer.from_pretrained(args.model_name_or_path)
    else:
        tokenizer = BertTokenizer.from_pretrained(args.model_name_or_path)

    dev_dataset = Poetry.get_datasets(['dev'])
    # Use a dedicated [ATTN] token when the vocab has one, otherwise [MASK].
    attn_id = tokenizer.vocab[
        '[ATTN]'] if '[ATTN]' in tokenizer.vocab else tokenizer.vocab['[MASK]']
    # Target segments use the last row of the sentence-type embedding table.
    tgt_type_id = model.sent_emb.weight.shape[0] - 1

    trans_func = convert_example(tokenizer=tokenizer,
                                 attn_id=attn_id,
                                 tgt_type_id=tgt_type_id,
                                 max_encode_len=args.max_encode_len,
                                 max_decode_len=args.max_decode_len)

    # Pad each of the 8 fields produced by trans_func with the pad id, then
    # post-process with after_padding (defined elsewhere), whose 12-field
    # output is unpacked in the loop below.
    batchify_fn = lambda samples, fn=Tuple(
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # src_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_pids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_sids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # attn_ids
        Pad(axis=0, pad_val=tokenizer.pad_token_id),  # tgt_labels
    ): after_padding(fn(samples))

    dev_dataset = dev_dataset.apply(trans_func, lazy=True)
    dev_batch_sampler = paddle.io.BatchSampler(dev_dataset,
                                               batch_size=args.batch_size,
                                               shuffle=False)
    data_loader = DataLoader(dataset=dev_dataset,
                             batch_sampler=dev_batch_sampler,
                             collate_fn=batchify_fn,
                             num_workers=0,
                             return_list=True)

    rouge1 = Rouge1()
    rouge2 = Rouge2()

    if args.init_checkpoint:
        model_state = paddle.load(args.init_checkpoint)
        model.set_state_dict(model_state)

    model.eval()
    vocab = tokenizer.vocab
    eos_id = vocab[tokenizer.sep_token]
    sos_id = vocab[tokenizer.cls_token]
    pad_id = vocab[tokenizer.pad_token]
    unk_id = vocab[tokenizer.unk_token]
    vocab_size = len(vocab)
    evaluated_sentences_ids = []
    reference_sentences_ids = []
    logger.info("Evaluating...")
    # Pure inference: make sure no autograd graph is recorded while decoding,
    # regardless of whether the decoding helper disables it itself.
    with paddle.no_grad():
        for data in tqdm(data_loader):
            (src_ids, src_sids, src_pids, _, _, _, _, _, _, _, _,
             raw_tgt_labels) = data  # never use target when infer
            # Use greedy_search_infilling or beam_search_infilling to get predictions
            output_ids = beam_search_infilling(model,
                                               src_ids,
                                               src_sids,
                                               eos_id=eos_id,
                                               sos_id=sos_id,
                                               attn_id=attn_id,
                                               pad_id=pad_id,
                                               unk_id=unk_id,
                                               vocab_size=vocab_size,
                                               max_decode_len=args.max_decode_len,
                                               max_encode_len=args.max_encode_len,
                                               beam_width=args.beam_width,
                                               length_penalty=args.length_penalty,
                                               tgt_type_id=tgt_type_id)

            # Predictions: keep everything before the first EOS, if any.
            for ids in output_ids.tolist():
                if eos_id in ids:
                    ids = ids[:ids.index(eos_id)]
                evaluated_sentences_ids.append(ids)

            # References: cut at the first EOS.
            # NOTE(review): the reference keeps its first id here, whereas the
            # predict() routine slices references as ids[1:eos]. If
            # beam_search_infilling output does not include that leading id,
            # ROUGE is computed against a reference with one extra token —
            # confirm which slicing is intended.
            reference_sentences_ids.extend(
                ids[:ids.index(eos_id)]
                for ids in raw_tgt_labels.numpy().tolist())

    score1 = rouge1.score(evaluated_sentences_ids, reference_sentences_ids)
    score2 = rouge2.score(evaluated_sentences_ids, reference_sentences_ids)

    logger.info("Rouge-1: %.5f ,Rouge-2: %.5f" % (score1 * 100, score2 * 100))