Exemplo n.º 1
0
def do_predict(args):
    """Decode ``args.predict_file`` with an ``InferTransformer`` and write the results.

    Builds the inference dataset and data loader, constructs the model,
    loads the weights from ``args.init_from_params`` and writes up to
    ``args.n_best`` decoded sequences per input instance to
    ``args.output_file``, one space-joined sequence per line (as bytes).

    Args:
        args: Configuration namespace. Besides the attributes read below
            (device flags, vocabulary paths, model hyper-parameters,
            ``beam_size``, ``max_out_len``, ``n_best``, ...), this function
            *sets* ``src_vocab_size``, ``trg_vocab_size``, ``bos_idx``,
            ``eos_idx`` and ``unk_idx`` on it from the dataset's vocabulary
            summary.

    Raises:
        AssertionError: If ``args.init_from_params`` is empty/falsy.
    """
    device = paddle.set_device("gpu" if args.use_cuda else "cpu")
    if args.eager_run:
        fluid.enable_dygraph(device)

    inputs = [
        Input([None, None], "int64", name="src_word"),
        Input([None, None], "int64", name="src_pos"),
        Input([None, args.n_head, None, None],
              "float32",
              name="src_slf_attn_bias"),
        Input([None, args.n_head, None, None],
              "float32",
              name="trg_src_attn_bias"),
    ]

    # define data
    dataset = Seq2SeqDataset(fpattern=args.predict_file,
                             src_vocab_fpath=args.src_vocab_fpath,
                             trg_vocab_fpath=args.trg_vocab_fpath,
                             token_delimiter=args.token_delimiter,
                             start_mark=args.special_token[0],
                             end_mark=args.special_token[1],
                             unk_mark=args.special_token[2],
                             byte_data=True)
    # The vocabulary summary is stored on ``args`` because the model
    # construction below reads the sizes and special-token ids from there.
    args.src_vocab_size, args.trg_vocab_size, args.bos_idx, args.eos_idx, \
        args.unk_idx = dataset.get_vocab_summary()
    # Reverse (id -> word) target dictionary; loaded with byte_data=True, so
    # the words are joined and written as bytes below.
    trg_idx2word = Seq2SeqDataset.load_dict(dict_path=args.trg_vocab_fpath,
                                            reverse=True,
                                            byte_data=True)
    batch_sampler = Seq2SeqBatchSampler(dataset=dataset,
                                        use_token_batch=False,
                                        batch_size=args.batch_size,
                                        max_length=args.max_length)
    data_loader = DataLoader(dataset=dataset,
                             batch_sampler=batch_sampler,
                             places=device,
                             collate_fn=partial(prepare_infer_input,
                                                bos_idx=args.bos_idx,
                                                eos_idx=args.eos_idx,
                                                # the EOS id doubles as the source pad id
                                                src_pad_idx=args.eos_idx,
                                                n_head=args.n_head),
                             num_workers=0,
                             return_list=True)

    # define model
    model = paddle.Model(
        InferTransformer(args.src_vocab_size,
                         args.trg_vocab_size,
                         args.max_length + 1,
                         args.n_layer,
                         args.n_head,
                         args.d_key,
                         args.d_value,
                         args.d_model,
                         args.d_inner_hid,
                         args.prepostprocess_dropout,
                         args.attention_dropout,
                         args.relu_dropout,
                         args.preprocess_cmd,
                         args.postprocess_cmd,
                         args.weight_sharing,
                         args.bos_idx,
                         args.eos_idx,
                         beam_size=args.beam_size,
                         max_out_len=args.max_out_len), inputs)
    model.prepare()

    # load the trained model
    assert args.init_from_params, (
        "Please set init_from_params to load the infer model.")
    model.load(args.init_from_params)

    # TODO: use model.predict when support variant length
    # The context manager guarantees the output file is flushed and closed
    # even if decoding raises part-way through.
    with open(args.output_file, "wb") as f:
        for data in data_loader():
            finished_seq = model.test_batch(inputs=flatten(data))[0]
            # Swap the last two axes so that iterating an instance yields one
            # beam at a time (presumably [batch, len, beam] ->
            # [batch, beam, len] -- TODO confirm against InferTransformer).
            finished_seq = np.transpose(finished_seq, [0, 2, 1])
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    # keep only the first n_best beams of each instance
                    if beam_idx >= args.n_best:
                        break
                    id_list = post_process_seq(beam, args.bos_idx,
                                               args.eos_idx)
                    word_list = [trg_idx2word[idx] for idx in id_list]
                    sequence = b" ".join(word_list) + b"\n"
                    f.write(sequence)
Exemplo n.º 2
0
def do_predict(args):
    """Decode ``args.infer_file`` with a seq2seq inference model.

    Builds the inference dataset and loader, instantiates
    ``AttentionInferModel`` or ``BaseInferModel`` depending on
    ``args.attention``, loads the weights from ``args.reload_model`` and
    writes the first beam of every instance to ``args.infer_output_file``
    (UTF-8 text, one space-joined sequence per line).

    Args:
        args: Configuration namespace. ``src_vocab_size`` and
            ``trg_vocab_size`` are set on it from the dataset's vocabulary
            summary.

    Raises:
        AssertionError: If ``args.reload_model`` is empty/falsy.
    """
    device = paddle.set_device("gpu" if args.use_gpu else "cpu")
    if args.eager_run:
        fluid.enable_dygraph(device)

    # model input specification
    input_specs = [
        Input([None, None], "int64", name="src_word"),
        Input([None], "int64", name="src_length"),
    ]

    # dataset and loader
    src_vocab_path = args.vocab_prefix + "." + args.src_lang
    trg_vocab_path = args.vocab_prefix + "." + args.tar_lang
    infer_dataset = Seq2SeqDataset(
        fpattern=args.infer_file,
        src_vocab_fpath=src_vocab_path,
        trg_vocab_fpath=trg_vocab_path,
        token_delimiter=None,
        start_mark="<s>",
        end_mark="</s>",
        unk_mark="<unk>")
    # reverse (id -> word) dictionary for the target language
    trg_idx2word = Seq2SeqDataset.load_dict(
        dict_path=trg_vocab_path, reverse=True)
    vocab_summary = infer_dataset.get_vocab_summary()
    (args.src_vocab_size, args.trg_vocab_size, bos_id, eos_id,
     unk_id) = vocab_summary
    sampler = Seq2SeqBatchSampler(
        dataset=infer_dataset,
        use_token_batch=False,
        batch_size=args.batch_size)
    # the EOS id doubles as the padding id when collating
    collate = partial(
        prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=eos_id)
    loader = DataLoader(
        dataset=infer_dataset,
        batch_sampler=sampler,
        places=device,
        collate_fn=collate,
        num_workers=0,
        return_list=True)

    # model
    if args.attention:
        infer_model_cls = AttentionInferModel
    else:
        infer_model_cls = BaseInferModel
    # NOTE(review): the target size passed here is ``args.tar_vocab_size``,
    # not the ``args.trg_vocab_size`` assigned above from the vocabulary
    # summary -- confirm this is intentional.
    network = infer_model_cls(
        args.src_vocab_size,
        args.tar_vocab_size,
        args.hidden_size,
        args.hidden_size,
        args.num_layers,
        args.dropout,
        bos_id=bos_id,
        eos_id=eos_id,
        beam_size=args.beam_size,
        max_out_len=256)
    model = paddle.Model(network, inputs=input_specs)
    model.prepare()

    # load the trained model
    assert args.reload_model, (
        "Please set reload_model to load the infer model.")
    model.load(args.reload_model)

    # TODO(guosheng): use model.predict when support variant length
    with io.open(args.infer_output_file, 'w', encoding='utf-8') as out_file:
        for batch in loader():
            decoded = model.test_batch(inputs=flatten(batch))[0]
            # a 2-D result gets a trailing axis so the transpose below
            # always sees three dimensions
            if len(decoded.shape) == 2:
                decoded = decoded[:, :, np.newaxis]
            decoded = np.transpose(decoded, [0, 2, 1])
            for ins in decoded:
                # only the first beam of each instance is written
                for beam in ins[:1]:
                    token_ids = post_process_seq(beam, bos_id, eos_id)
                    words = [trg_idx2word[token] for token in token_ids]
                    out_file.write(" ".join(words) + "\n")