Example 1
def do_predict(args):
    device = paddle.set_device("gpu" if args.use_gpu else "cpu")

    test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
        args)
    _, vocab = IWSLT15.get_vocab()

    trg_idx2word = vocab.idx_to_token

    model = paddle.Model(
        Seq2SeqAttnInferModel(
            src_vocab_size,
            tgt_vocab_size,
            args.hidden_size,
            args.hidden_size,
            args.num_layers,
            args.dropout,
            bos_id=bos_id,
            eos_id=eos_id,
            beam_size=args.beam_size,
            max_out_len=256))

    model.prepare()

    # Load the trained model
    assert args.init_from_ckpt, (
        "Please set init_from_ckpt to load the trained model.")
    model.load(args.init_from_ckpt)

    cand_list = []
    with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
        for data in test_loader():
            with paddle.no_grad():
                finished_seq = model.predict_batch(inputs=data)[0]
            # Beam-search output has shape [batch, seq_len, beam_size]; add a
            # beam axis when only one beam comes back, then move beams to
            # axis 1 so each row is one candidate sequence.
            finished_seq = finished_seq[:, :, np.newaxis] if len(
                finished_seq.shape) == 2 else finished_seq
            finished_seq = np.transpose(finished_seq, [0, 2, 1])
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    id_list = post_process_seq(beam, bos_id, eos_id)
                    word_list = [trg_idx2word[id] for id in id_list]
                    sequence = " ".join(word_list) + "\n"
                    f.write(sequence)
                    cand_list.append(word_list)
                    # Keep only the top beam for each instance.
                    break

    test_ds = IWSLT15.get_datasets(["test"])

    bleu = BLEU()
    for i, data in enumerate(test_ds):
        ref = data[1].split()
        bleu.add_inst(cand_list[i], [ref])
    print("BLEU score is %s." % bleu.score())
Example 2
def create_infer_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    trans_func_tuple = IWSLT15.get_default_transform_func()
    test_ds = IWSLT15.get_datasets(
        mode=["test"], transform_func=[trans_func_tuple])
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    # Reuse eos as the padding token.
    pad_id = eos_id

    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)

    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
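
prepare_infer_input only appears here through functools.partial. A plausible sketch, assuming each dataset item is a (src_ids, tgt_ids) pair and that the collate function wraps the source with bos/eos and pads it using Pad from paddlenlp.data (ret_length=True additionally returns the unpadded lengths):

from paddlenlp.data import Pad


def prepare_infer_input(insts, bos_id, eos_id, pad_id):
    # Wrap each source id list with bos/eos before padding.
    srcs = [[bos_id] + inst[0] + [eos_id] for inst in insts]
    # Pad to the longest sequence in the batch and keep the true lengths.
    src, src_length = Pad(pad_val=pad_id, ret_length=True)(srcs)
    return src, src_length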
Example 3
def main():
    args = parse_args()

    predictor = Predictor.create_predictor(args)
    test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
        args)
    _, vocab = IWSLT15.get_vocab()
    trg_idx2word = vocab.idx_to_token

    predictor.predict(test_loader, args.infer_output_file, trg_idx2word,
                      bos_id, eos_id)
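
Predictor.create_predictor is not shown in these examples. A hedged sketch of what such a factory might look like on top of the paddle.inference API; the export_path and use_gpu attributes on args are assumptions, not confirmed by the source:

import paddle.inference as paddle_infer


class Predictor(object):
    def __init__(self, predictor, input_handles, output_handles):
        self.predictor = predictor
        self.input_handles = input_handles
        self.output_handles = output_handles

    @classmethod
    def create_predictor(cls, args):
        # args.export_path is a hypothetical prefix of the exported
        # inference model files (.pdmodel / .pdiparams).
        config = paddle_infer.Config(args.export_path + ".pdmodel",
                                     args.export_path + ".pdiparams")
        if args.use_gpu:
            config.enable_use_gpu(100, 0)  # 100 MB initial pool on GPU 0
        else:
            config.disable_gpu()
        predictor = paddle_infer.create_predictor(config)
        input_handles = [
            predictor.get_input_handle(name)
            for name in predictor.get_input_names()
        ]
        output_handles = [
            predictor.get_output_handle(name)
            for name in predictor.get_output_names()
        ]
        return cls(predictor, input_handles, output_handles)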
Example 4
def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    # Defined here so this example is self-contained: the default transform
    # converts raw text into token ids.
    trans_func_tuple = IWSLT15.get_default_transform_func()
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])

    # Sort by source length so similarly sized examples are batched together,
    # and truncate both sides to max_len tokens.
    key = (lambda x, data_source: len(data_source[x][0]))
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])

    train_ds = train_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    dev_ds = dev_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)

    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))

    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id
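
prepare_train_input is likewise only referenced through partial. A sketch under the same assumptions as the inference collate function: both sides get bos/eos, are padded with the eos id, and the target is split into a shifted decoder input and labels, with a mask so padding does not count toward the loss:

import numpy as np
from paddlenlp.data import Pad


def prepare_train_input(insts, bos_id, eos_id, pad_id):
    # Wrap source and target with bos/eos, then pad each to its batch max.
    insts = [([bos_id] + inst[0] + [eos_id], [bos_id] + inst[1] + [eos_id])
             for inst in insts]
    src, src_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[0] for inst in insts])
    tgt, tgt_length = Pad(pad_val=pad_id, ret_length=True)(
        [inst[1] for inst in insts])
    # Teacher forcing: feed tgt[:-1], predict tgt[1:]; the mask keeps only
    # real tokens so padding does not contribute to the loss.
    tgt_mask = (tgt[:, :-1] != pad_id).astype("float32")
    return src, src_length, tgt[:, :-1], tgt[:, 1:, np.newaxis], tgt_mask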
Example 5
    def predict(self, dataloader, infer_output_file, trg_idx2word, bos_id,
                eos_id):
        cand_list = []
        with io.open(infer_output_file, 'w', encoding='utf-8') as f:
            for data in dataloader():
                finished_seq = self.predict_batch(data)[0]
                finished_seq = finished_seq[:, :, np.newaxis] if len(
                    finished_seq.shape) == 2 else finished_seq
                finished_seq = np.transpose(finished_seq, [0, 2, 1])
                for ins in finished_seq:
                    for beam_idx, beam in enumerate(ins):
                        id_list = post_process_seq(beam, bos_id, eos_id)
                        word_list = [trg_idx2word[id] for id in id_list]
                        sequence = " ".join(word_list) + "\n"
                        f.write(sequence)
                        cand_list.append(word_list)
                        break

        test_ds = IWSLT15.get_datasets(["test"])
        bleu = BLEU()
        for i, data in enumerate(test_ds):
            ref = data[1].split()
            bleu.add_inst(cand_list[i], [ref])
        print("BLEU score is %s." % bleu.score())
Example 6

import io
import os

from functools import partial
import numpy as np

import paddle
from paddlenlp.data import Vocab, Pad
from paddlenlp.data import SamplerHelper

from paddlenlp.datasets import IWSLT15

trans_func_tuple = IWSLT15.get_default_transform_func()


def create_train_loader(args):
    batch_size = args.batch_size
    max_len = args.max_len
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    pad_id = eos_id

    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])

    key = (lambda x, data_source: len(data_source[x][0]))