def main():
    """Entry point: build a predictor from CLI args and run inference.

    Writes the decoded target-language sequences to ``args.infer_output_file``.
    """
    args = parse_args()
    predictor = Predictor.create_predictor(args)
    # Vocab sizes returned by the loader are not needed here, only the
    # loader itself and the special-token ids.
    loader, _src_size, _tgt_size, bos_id, eos_id = create_infer_loader(args)
    # Target-side vocab maps token ids back to surface words for output.
    _, tgt_vocab = IWSLT15.get_vocab()
    id_to_word = tgt_vocab.idx_to_token
    predictor.predict(loader, args.infer_output_file, id_to_word, bos_id,
                      eos_id)
def do_predict(args):
    """Run beam-search inference over the IWSLT15 test set and report BLEU.

    Loads a trained Seq2SeqAttnInferModel checkpoint, decodes every test
    batch, writes the top beam of each instance to ``args.infer_output_file``
    (one space-joined sentence per line), and prints the corpus BLEU score
    against the reference translations.

    Args:
        args: parsed CLI namespace; uses use_gpu, hidden_size, num_layers,
            dropout, beam_size, init_from_ckpt, infer_output_file.
    """
    device = paddle.set_device("gpu" if args.use_gpu else "cpu")
    test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id = create_infer_loader(
        args)
    # Target-side vocab for mapping decoded ids back to words.
    _, vocab = IWSLT15.get_vocab()
    trg_idx2word = vocab.idx_to_token
    model = paddle.Model(
        Seq2SeqAttnInferModel(
            src_vocab_size,
            tgt_vocab_size,
            args.hidden_size,
            args.hidden_size,
            args.num_layers,
            args.dropout,
            bos_id=bos_id,
            eos_id=eos_id,
            beam_size=args.beam_size,
            max_out_len=256))
    model.prepare()

    # Load the trained model
    assert args.init_from_ckpt, (
        "Please set reload_model to load the infer model.")
    model.load(args.init_from_ckpt)

    cand_list = []
    with io.open(args.infer_output_file, 'w', encoding='utf-8') as f:
        for data in test_loader():
            with paddle.no_grad():
                finished_seq = model.predict_batch(inputs=data)[0]
            # Normalize shape to (batch, seq_len, beam): a 2-D result
            # (single beam) gets a trailing beam axis added, then axes are
            # swapped to (batch, beam, seq_len) for per-beam iteration.
            finished_seq = finished_seq[:, :, np.newaxis] if len(
                finished_seq.shape) == 2 else finished_seq
            finished_seq = np.transpose(finished_seq, [0, 2, 1])
            for ins in finished_seq:
                for beam_idx, beam in enumerate(ins):
                    # Strip BOS/EOS and anything after EOS.
                    id_list = post_process_seq(beam, bos_id, eos_id)
                    word_list = [trg_idx2word[id] for id in id_list]
                    sequence = " ".join(word_list) + "\n"
                    f.write(sequence)
                    cand_list.append(word_list)
                    # Deliberate: keep only the top (first) beam per instance.
                    break

    # Score the written candidates against the raw test references.
    test_ds = IWSLT15.get_datasets(["test"])
    bleu = BLEU()
    for i, data in enumerate(test_ds):
        # data[1] is assumed to be the reference target sentence — TODO
        # confirm against IWSLT15.get_datasets output format.
        ref = data[1].split()
        bleu.add_inst(cand_list[i], [ref])
    print("BLEU score is %s." % bleu.score())
def create_infer_loader(args):
    """Build the test-set DataLoader for inference.

    Applies the dataset's default transform to the IWSLT15 test split and
    batches it (no shuffling, no length bucketing) with padding handled by
    ``prepare_infer_input``.

    Args:
        args: parsed CLI namespace; uses ``batch_size``.

    Returns:
        Tuple of (test_loader, src_vocab_size, tgt_vocab_size, bos_id, eos_id).
    """
    batch_size = args.batch_size
    # NOTE: removed unused local `max_len = args.max_len` — inference does
    # not truncate inputs (only create_train_loader uses max_len).
    trans_func_tuple = IWSLT15.get_default_transform_func()
    test_ds = IWSLT15.get_datasets(
        mode=["test"], transform_func=[trans_func_tuple])
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    # Pad with EOS so trailing pads look like sequence end to the decoder.
    pad_id = eos_id
    test_batch_sampler = SamplerHelper(test_ds).batch(batch_size=batch_size)
    test_loader = paddle.io.DataLoader(
        test_ds,
        batch_sampler=test_batch_sampler,
        collate_fn=partial(
            prepare_infer_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return test_loader, len(src_vocab), len(tgt_vocab), bos_id, eos_id
def create_train_loader(args):
    """Build train and dev DataLoaders for IWSLT15.

    Filters out empty pairs, truncates both sides to ``args.max_len``,
    length-buckets via sort-within-buffer (train additionally shuffles),
    and pads batches with ``prepare_train_input``.

    Args:
        args: parsed CLI namespace; uses ``batch_size`` and ``max_len``.

    Returns:
        Tuple of (train_loader, dev_loader, src_vocab_size, tgt_vocab_size,
        pad_id).
    """
    batch_size = args.batch_size
    max_len = args.max_len
    # BUG FIX: `trans_func_tuple` was referenced below but never defined in
    # this function (NameError at runtime). Define it the same way
    # create_infer_loader does.
    trans_func_tuple = IWSLT15.get_default_transform_func()
    src_vocab, tgt_vocab = IWSLT15.get_vocab()
    bos_id = src_vocab[src_vocab.bos_token]
    eos_id = src_vocab[src_vocab.eos_token]
    # Pad with EOS so trailing pads look like sequence end to the decoder.
    pad_id = eos_id
    train_ds, dev_ds = IWSLT15.get_datasets(
        mode=["train", "dev"],
        transform_func=[trans_func_tuple, trans_func_tuple])
    # Sort key: source-sentence length, for length bucketing.
    key = (lambda x, data_source: len(data_source[x][0]))
    # Truncate both source and target to max_len tokens.
    cut_fn = lambda data: (data[0][:max_len], data[1][:max_len])
    train_ds = train_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    dev_ds = dev_ds.filter(
        lambda data: (len(data[0]) > 0 and len(data[1]) > 0)).apply(cut_fn)
    train_batch_sampler = SamplerHelper(train_ds).shuffle().sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    dev_batch_sampler = SamplerHelper(dev_ds).sort(
        key=key, buffer_size=batch_size * 20).batch(batch_size=batch_size)
    train_loader = paddle.io.DataLoader(
        train_ds,
        batch_sampler=train_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    dev_loader = paddle.io.DataLoader(
        dev_ds,
        batch_sampler=dev_batch_sampler,
        collate_fn=partial(
            prepare_train_input, bos_id=bos_id, eos_id=eos_id, pad_id=pad_id))
    return train_loader, dev_loader, len(src_vocab), len(tgt_vocab), pad_id