Example 1
def main():
    """Entry point: build a predictor and run inference on the test set.

    Parses command-line arguments, creates the inference data loader and
    target vocabulary, then writes predictions to ``args.infer_output_file``.
    """
    args = parse_args()

    predictor = Predictor.create_predictor(args)
    (test_loader, src_vocab_size, tgt_vocab_size, bos_id,
     eos_id) = create_infer_loader(args)
    _, vocab = IWSLT15.get_vocab()
    # Map target-side token ids back to surface words for output.
    idx_to_word = vocab.idx_to_token

    predictor.predict(test_loader, args.infer_output_file, idx_to_word,
                      bos_id, eos_id)
Example 2
def do_predict(args):
    """Decode the IWSLT15 test set with a Seq2Seq-with-attention model and
    report corpus BLEU.

    Loads a trained checkpoint (``args.init_from_ckpt`` is required), runs
    beam-search inference batch by batch, writes the top-beam hypothesis for
    every instance to ``args.infer_output_file`` (one sentence per line), and
    finally scores the hypotheses against the test references.
    """
    device = paddle.set_device("gpu" if args.use_gpu else "cpu")

    test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
        args)
    _, vocab = IWSLT15.get_vocab()
    # id -> word table for detokenizing model output.
    trg_idx2word = vocab.idx_to_token

    inference_net = Seq2SeqAttnInferModel(
        src_vocab_size,
        tgt_vocab_size,
        args.hidden_size,
        args.hidden_size,
        args.num_layers,
        args.dropout,
        bos_id=bos_id,
        eos_id=eos_id,
        beam_size=args.beam_size,
        max_out_len=256)
    model = paddle.Model(inference_net)
    model.prepare()

    # Load the trained model
    assert args.init_from_ckpt, (
        "Please set reload_model to load the infer model.")
    model.load(args.init_from_ckpt)

    cand_list = []
    with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
        for batch in test_loader():
            with paddle.no_grad():
                finished_seq = model.predict_batch(inputs=batch)[0]
            # Normalize to rank 3 (batch, step, beam), then put the beam
            # axis second: (batch, beam, step).
            if len(finished_seq.shape) == 2:
                finished_seq = finished_seq[:, :, np.newaxis]
            finished_seq = np.transpose(finished_seq, [0, 2, 1])
            for ins in finished_seq:
                # Keep only the highest-scoring beam per instance.
                for beam in ins[:1]:
                    token_ids = post_process_seq(beam, bos_id, eos_id)
                    words = [trg_idx2word[token] for token in token_ids]
                    f.write(" ".join(words) + "\n")
                    cand_list.append(words)

    test_ds = IWSLT15.get_datasets(["test"])

    # NOTE(review): assumes sample[1] is the whitespace-tokenizable target
    # side of each test pair — confirm against IWSLT15.get_datasets.
    bleu = BLEU()
    for idx, sample in enumerate(test_ds):
        bleu.add_inst(cand_list[idx], [sample[1].split()])
    print("BLEU score is %s." % bleu.score())
Example 3
def create_infer_loader(args):
    """Build the test-set DataLoader for inference.

    Applies the dataset's default transform to the IWSLT15 test split and
    batches it without shuffling or length-bucketing. Padding uses the EOS
    id, matching the training-side convention.

    Args:
        args: Parsed arguments; reads ``args.batch_size``.

    Returns:
        tuple: ``(test_loader, src_vocab_size, tgt_vocab_size, bos_id,
        eos_id)``.
    """
    batch_size = args.batch_size
    # NOTE: unlike the train loader, inference does not truncate sequences,
    # so args.max_len is intentionally not read here.
    trans_func_tuple = IWSLT15.get_default_transform_func()
    test_ds = IWSLT15.get_datasets(
        mode=["test"], transform_func=[trans_func_tuple])
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    # Pad with EOS so padded positions look like sequence end to the model.
    pad_id = eos_id

    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
Example 4
def create_train_loader(args):
    """Build the train and dev DataLoaders for seq2seq training.

    Bug fix: the original body referenced ``trans_func_tuple`` without ever
    defining it (guaranteed NameError); it is now obtained from
    ``IWSLT15.get_default_transform_func()``, consistent with
    ``create_infer_loader``.

    Empty pairs are dropped, both sides are truncated to ``args.max_len``,
    and batches are length-bucketed (sort within a shuffle window) to reduce
    padding. Padding uses the EOS id.

    Args:
        args: Parsed arguments; reads ``args.batch_size`` and
            ``args.max_len``.

    Returns:
        tuple: ``(train_loader, dev_loader, src_vocab_size, tgt_vocab_size,
        pad_id)``.
    """
    batch_size = args.batch_size
    max_len = args.max_len
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    # Pad with EOS so padded positions look like sequence end to the model.
    pad_id = eos_id

    # Fix: define the transform before use (was an undefined name).
    trans_func_tuple = IWSLT15.get_default_transform_func()
    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])

    # Sort key: source-side length, so batches contain similar lengths.
    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])

    train_ds = train_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    dev_ds = dev_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    # Shuffle globally, then sort inside a 20-batch window for bucketing.
    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id