def main():
    '''Translate ``-src`` with a trained model and write the predictions.

    The source file is read with the settings stored in ``-vocab``, mapped to
    ids with its source vocabulary, decoded batch by batch with
    ``Translator.translate_batch``, and every returned hypothesis is written as
    one space-separated line to ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Preprocessed data file (.pt) with the settings and vocabularies')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader: read the source file with the preprocessing
    # settings, then map words to ids with the source vocabulary.
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_data = DataLoader(preprocess_data['dict']['src'],
                           preprocess_data['dict']['tgt'],
                           src_insts=test_src_insts,
                           cuda=opt.cuda,
                           shuffle=False,
                           batch_size=opt.batch_size)

    translator = Translator(opt)
    translator.model.eval()
    tgt_idx2word = test_data.tgt_idx2word

    # Inference only: no_grad() keeps autograd from recording the forward
    # passes. An explicit UTF-8 encoding makes the output independent of the
    # platform's default locale.
    with torch.no_grad(), open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_data,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, _ = translator.translate_batch(batch)
            # all_hyp holds one list of hypotheses per source sentence; each
            # hypothesis goes on its own output line.
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join(tgt_idx2word[idx] for idx in idx_seq)
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
示例#2
0
def main():
    '''Translate ``-src`` with a trained model and write the predictions.

    Batches are produced by a ``torch.utils.data.DataLoader`` over
    ``TranslationDataset``. Every hypothesis returned by
    ``Translator.translate_batch`` is written as one space-separated line to
    ``-output``.
    '''
    parser = argparse.ArgumentParser()

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Preprocessed data file (.pt) with the settings and vocabularies')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                            be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                            decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')  # flag present -> True

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)
    translator = Translator(opt)
    tgt_idx2word = test_loader.dataset.tgt_idx2word

    # Inference only: no_grad() avoids recording autograd graphs; UTF-8 makes
    # the output file independent of the platform locale.
    with torch.no_grad(), open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            # The collated batch is passed to translate_batch positionally.
            all_hyp, _ = translator.translate_batch(*batch)
            for hyp_stream in all_hyp:
                for hyp in hyp_stream:
                    pred_sent = ' '.join(tgt_idx2word[idx] for idx in hyp)
                    f.write(pred_sent + '\n')
    print('[Info] Finished')
def main():
    '''Translate ``-src`` with a trained model and write the predictions.

    The source file is read with the settings stored in ``-vocab`` and mapped
    to ids with its source vocabulary. Every hypothesis returned by
    ``Translator.translate_batch`` is written as one space-separated line to
    ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True,
                        help='Path to model .pt file')
    parser.add_argument('-src', required=True,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=True,
                        help='Preprocessed data file (.pt) with the settings and vocabularies')
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30,
                        help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_data = DataLoader(
        preprocess_data['dict']['src'],
        preprocess_data['dict']['tgt'],
        src_insts=test_src_insts,
        cuda=opt.cuda,
        shuffle=False,
        batch_size=opt.batch_size)

    translator = Translator(opt)
    translator.model.eval()
    tgt_idx2word = test_data.tgt_idx2word

    # Inference only: no_grad() avoids recording autograd graphs; UTF-8 makes
    # the output file independent of the platform locale.
    with torch.no_grad(), open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_data, mininterval=2, desc='  - (Test)', leave=False):
            all_hyp, _ = translator.translate_batch(batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join(tgt_idx2word[idx] for idx in idx_seq)
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
def main():
    '''Translate one split of a pickled code/docstring dataset.

    The pickle given by ``-data`` must provide ``dict['src']``/``dict['tgt']``
    vocabularies and, for the chosen ``-split`` (default ``train``, as before),
    ``src`` and ``tgt`` instance lists. Every hypothesis returned by
    ``Translator.translate_batch`` is written as one space-separated line to
    ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-data', required=True,
                        help='Pickled preprocessed data (vocabularies and splits)')
    parser.add_argument('-split', default='train',
                        help='Split of the data file to translate (default: train)')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader.
    # SECURITY: pickle.load can execute arbitrary code while loading; only
    # pass data files produced by a trusted preprocessing step.
    with open(opt.data, "rb") as f:
        data = pickle.load(f)
    split_data = data[opt.split]

    test_loader = torch.utils.data.DataLoader(
        CodeDocstringDatasetPreprocessed(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=split_data['src'],
            tgt_insts=split_data['tgt']),
        num_workers=0,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    translator = Translator(opt)
    tgt_idx2word = test_loader.dataset.tgt_idx2word

    # Inference only: no_grad() avoids recording autograd graphs; UTF-8 makes
    # the output file independent of the platform locale.
    with torch.no_grad(), open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            # Only the first two collated items are decoding inputs
            # (presumably source sequence and positions; the rest are targets).
            all_hyp, _ = translator.translate_batch(*batch[:2])
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join(tgt_idx2word[idx] for idx in idx_seq)
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
示例#5
0
def eval_bleu_score(opt, model, data, device, epoch, split='dev'):
    """Decode ``data`` with ``model``, save the hypotheses and run the BLEU script.

    Each hypothesis is decoded with the module-level ``sp`` object
    (``sp.DecodeIds``, presumably a SentencePiece processor) and written, one
    per line, to ``<opt.save_model_dir>/mypreds_<split>_<epoch>.hyp``.
    Afterwards ``sh calcBLEU.sh <split> <opt.save_model_dir> <epoch>`` is run;
    its exit status is ignored, as before.

    Args:
        opt: options; ``opt.save_model_dir`` is read here, and ``opt`` is
            also passed to ``Translator``.
        model: model handed to ``Translator`` (not loaded from file).
        data: iterable of 4-tuples ``(src_seq, src_pos, tgt_seq, tgt_pos)``.
        device: device the batch tensors are moved to.
        epoch: epoch number, used in the output file name and script call.
        split: split name, used in the output file name and script call.
    """
    import subprocess

    translator = Translator(opt, model, load_from_file=False)
    hyp_file = os.path.join(opt.save_model_dir, f'mypreds_{split}_{epoch}.hyp')
    # `with` guarantees the file is closed (and flushed before the BLEU
    # script reads it) even if decoding raises.
    with open(hyp_file, 'w') as outfile:
        for batch in tqdm(data, mininterval=2, desc='  - (Test)', leave=False):
            # Target tensors are not needed for decoding.
            src_seq, src_pos, _, _ = map(lambda x: x.to(device), batch)
            all_hyp, _ = translator.translate_batch(src_seq, src_pos)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    outfile.write(sp.DecodeIds(idx_seq) + '\n')
    # Pass the arguments as a list (no shell) so directory names containing
    # spaces or shell metacharacters are neither split nor interpreted.
    subprocess.run(['sh', 'calcBLEU.sh', split, opt.save_model_dir, str(epoch)],
                   check=False)
示例#6
0
    def attributor_batch_beam(self, training_data, opt):
        """Beam-decode every batch of ``training_data`` and print the hypotheses.

        Args:
            training_data: iterable of 4-tuples
                ``(src_seq, src_pos, tgt_seq, tgt_pos)``; all four are moved to
                ``self.device`` but only the source tensors are decoded.
            opt: options used to construct the ``Translator``.
        """
        def to_device(x):
            # Tensor.to() is not in-place: it returns the moved tensor, so the
            # result must be returned. The original discarded it and left the
            # batch on its original device.
            return x.to(self.device)

        translator = Translator(opt)
        for batch in tqdm(training_data,
                          mininterval=2,
                          desc='  - (Attributing)',
                          leave=False):
            src_seq, src_pos, _tgt_seq, _tgt_pos = map(to_device, batch)
            # Returns translations and beam-search scores. The meaning of the
            # positional False is defined by translate_batch (not visible here).
            all_hyp, _ = translator.translate_batch(src_seq, src_pos, False)
            print(all_hyp)
示例#7
0
文件: translate.py 项目: pmsgd/nl2sql
def main():
    '''Run beam-search inference for the nl2sql model on ``--src``.

    Loads the model, the English and SQL vocabularies and a data loader. Each
    batch is decoded with the SQL BOS/EOS/PAD ids, and one detokenised
    prediction per hypothesis list is written to ``--output``.
    '''
    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('--model', required=True, help='Path to model .pt file')
    parser.add_argument('--src', required=True, help='Source sequence to decode (one line per sequence)')
    parser.add_argument('--output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('--batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('--no_cuda', action='store_true')
    parser.add_argument('--beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('--n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                            decoded sentences""")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.device = torch.device('cuda' if opt.cuda else 'cpu')

    # Model, vocabularies and the source-side loader.
    model, model_opt = load_model(opt)
    en_vocab, sql_vocab = load_vocabs(opt)
    loader, en_field = load_data(opt, en_vocab)

    # The reversible SQL field turns ids back into tokens; its vocabulary also
    # provides the special-token ids that translate_batch expects as keywords.
    sql_field = create_reversible_field(sql_vocab)
    special_ids = {
        keyword: sql_field.vocab[word]
        for keyword, word in (('bos_token', Constants.BOS_WORD),
                              ('eos_token', Constants.EOS_WORD),
                              ('pad_token', Constants.PAD_WORD))
    }

    print('[Info] Inference start.')
    translator = Translator(opt, model, model_opt)
    with open(opt.output, 'w') as out_file, torch.no_grad():
        progress = tqdm(loader, mininterval=2, desc='  - (Test)', leave=False)
        for batch in progress:
            hypotheses, _scores = translator.translate_batch(*batch, **special_ids)
            for seqs in hypotheses:
                tokens = sql_field.reverse(torch.LongTensor(seqs))
                out_file.write(' '.join(tokens) + '\n')
    print('[Info] Finished.')
示例#8
0
def decode(model, src_seq, src_pos, ctx_seq, ctx_pos, args, token_len, beam_size=10):
    """Beam-decode a batch and re-collate the best hypotheses as a new batch.

    Runs beam search with ``n_best=1``. Each best hypothesis is cut at the
    first ``Constants.EOS`` (the EOS itself is dropped) and then truncated to
    ``args.max_word_seq_len`` tokens. The results are batched again with
    ``collate_fn`` using ``max_len=token_len``.

    Args:
        model: model handed to ``Translator``.
        src_seq, src_pos, ctx_seq, ctx_pos: decoder inputs, forwarded to
            ``Translator.translate_batch`` unchanged.
        args: reads ``max_token_seq_len``, ``max_word_seq_len`` and ``device``.
        token_len: ``max_len`` passed to ``collate_fn``.
        beam_size: beam width. Defaults to 10, the value previously hard-coded.

    Returns:
        ``(batch_seq, batch_pos)`` from ``collate_fn``, moved to ``args.device``.
    """
    translator = Translator(max_token_seq_len=args.max_token_seq_len,
                            beam_size=beam_size,
                            n_best=1,
                            device=args.device,
                            bad_mask=None,
                            model=model)
    all_hyp, _ = translator.translate_batch(src_seq, src_pos, ctx_seq, ctx_pos)

    tgt_seq = []
    for idx_seqs in all_hyp:  # one entry per batch element
        idx_seq = idx_seqs[0]  # n_best=1, so the only hypothesis is the best
        # Position of the first EOS, or the full length when there is none.
        end_pos = next(
            (pos for pos, idx in enumerate(idx_seq) if idx == Constants.EOS),
            len(idx_seq))
        tgt_seq.append(idx_seq[:end_pos][:args.max_word_seq_len])
    batch_seq, batch_pos = collate_fn(tgt_seq, max_len=token_len)
    return batch_seq.to(args.device), batch_pos.to(args.device)
示例#9
0
def main():
    '''Translate ``-src`` with a trained model and write the predictions.

    The source file is read with the settings stored in ``-vocab``, mapped to
    ids, and batched by a ``torch.utils.data.DataLoader`` over
    ``TranslationDataset``. Every hypothesis returned by
    ``Translator.translate_batch`` is written as one space-separated line to
    ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Preprocessed data file (.pt) with the settings and vocabularies')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)
    tgt_idx2word = test_loader.dataset.tgt_idx2word

    # Inference only: no_grad() avoids recording autograd graphs; UTF-8 makes
    # the output file independent of the platform locale.
    with torch.no_grad(), open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, _ = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join(tgt_idx2word[idx] for idx in idx_seq)
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
def main():
    '''Evaluate a trained model on the Nlp4plp test split and save predictions.

    Loads ``<data_dir>test`` as an ``Nlp4plpCorpus`` and decodes every
    instance. The accuracy of the best hypothesis against the gold labels is
    printed. Each gold/predicted output is then written to
    ``<id>.pl_t`` / ``<id>.pl_p`` in a new timestamped directory under
    ``-dir_out``.

    Raises:
        ValueError: if ``--convert-consts`` is unknown or does not match
            ``-data_dir`` (``nums_mapped`` data is required for every mode
            except ``conv``, and forbidden for ``conv``).
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-data_dir', required=True)
    parser.add_argument('-debug', action='store_true')
    parser.add_argument('-dir_out', default="/home/suster/Apps/out/")
    parser.add_argument(
        "--convert-consts",
        type=str,
        help="conv | our-map | no-our-map | no. \n/"
        "conv-> txt: -; stats: num_sym+ent_sym.\n/"
        "our-map-> txt: num_sym; stats: num_sym(from map)+ent_sym;\n/"
        "no-our-map-> txt: -; stats: num_sym(from map)+ent_sym;\n/"
        "no-> txt: -; stats: -, only ent_sym;\n/"
        "no-ent-> txt: -; stats: -, no ent_sym;\n/")
    parser.add_argument(
        "--label-type-dec",
        type=str,
        default="full-pl",
        help=
        "predicates | predicates-all | predicates-arguments-all | full-pl | full-pl-no-arg-id | full-pl-split | full-pl-split-plc | full-pl-split-stat-dyn. To use with EncDec."
    )
    parser.add_argument('-vocab', required=True,
                        help='Preprocessed data file (.pt) with the vocabularies')
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    args = parser.parse_args()
    args.cuda = not args.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(args.vocab)

    # The constant-conversion mode must match the kind of data directory.
    # Explicit exceptions (not asserts) so the check survives `python -O`.
    if args.convert_consts in {"conv"}:
        if "nums_mapped" in args.data_dir:
            raise ValueError("--convert-consts conv requires a data_dir "
                             "without 'nums_mapped'")
    elif args.convert_consts in {"our-map", "no-our-map", "no", "no-ent"}:
        if "nums_mapped" not in args.data_dir:
            raise ValueError(f"--convert-consts {args.convert_consts} requires "
                             "a 'nums_mapped' data_dir")
    elif args.convert_consts is not None:
        raise ValueError(f"Unknown --convert-consts value: {args.convert_consts!r}")
    test_corp = Nlp4plpCorpus(args.data_dir + "test", args.convert_consts)

    if args.debug:
        test_corp.insts = test_corp.insts[:10]
    test_corp.get_labels(label_type=args.label_type_dec)
    test_corp.remove_none_labels()

    # Test set: source words and gold label words share the same instance ids.
    test_src_word_insts, test_src_id_insts = prepare_instances(test_corp.insts)
    test_tgt_word_insts, test_tgt_id_insts = prepare_instances(test_corp.insts,
                                                               label=True)
    assert test_src_id_insts == test_tgt_id_insts
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=0,
        batch_size=args.batch_size,
        collate_fn=collate_fn)

    translator = Translator(args)
    tgt_idx2word = test_loader.dataset.tgt_idx2word

    preds = []
    golds = []
    # Index of the current source instance. The loader is not shuffled, so
    # batches arrive in the same order as test_corp.insts.
    i = 0

    with torch.no_grad():
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, _ = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                # Score only the best hypothesis so there is exactly one
                # prediction per gold instance even with -n_best > 1. The
                # original advanced `i` per hypothesis and misaligned them.
                best = idx_seqs[0]
                pred = [
                    tgt_idx2word[idx] for idx in best
                    if tgt_idx2word[idx] != "</s>"
                ]
                gold = [
                    w for w in test_tgt_word_insts[i]
                    if w not in {"<s>", "</s>"}
                ]
                if args.convert_consts == "no":
                    num2n = None
                else:
                    assert test_corp.insts[i].id == test_src_id_insts[i]
                    num2n = test_corp.insts[i].num2n_map
                preds.append(final_repl(pred, num2n))
                golds.append(final_repl(gold, num2n))
                i += 1
    acc = accuracy_score(golds, preds)
    print(f"Accuracy: {acc:.3f}")
    print("Saving predictions from the best model:")

    assert len(test_src_id_insts) == len(test_src_word_insts) == len(
        preds) == len(golds)
    f_model = f'{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}'
    dir_out = f"{args.dir_out}log_w{f_model}/"
    print(f"Save preds dir: {dir_out}")
    os.makedirs(dir_out, exist_ok=True)
    for inst_id, gold, pred in zip(test_src_id_insts, golds, preds):
        # basename() strips any directory part of the id, so files always
        # land directly inside dir_out.
        path_t = os.path.join(dir_out, os.path.basename(f"{inst_id}.pl_t"))
        path_p = os.path.join(dir_out, os.path.basename(f"{inst_id}.pl_p"))
        # NOTE(review): write() requires final_repl to return a str — confirm.
        with open(path_t, "w") as f_out_t, open(path_p, "w") as f_out_p:
            f_out_t.write(gold)
            f_out_p.write(pred)

    print('[Info] Finished.')
示例#11
0
		net2.load_state_dict(checkpoint["state_dict_net2"])
		print("restore successfully!")
	else:
		print("fail to restore, path don't exist")
	
	translator = Translator(net2, beam_size=hp.beam_size, max_seq_len=hp.max_seq_len, n_best=hp.n_best)

	print("************************begin infer*********************")
	for imgs_name, imgs, length_imgs, labels, legnth_labels in testLoader:
		imgs = Variable(imgs).cuda()
		length_imgs = Variable(length_imgs).cuda()
      
		enc_img = net1(imgs.float())
		batch_size, channel, height, width = enc_img.shape
		enc_img = enc_img.permute(0, 2, 3, 1).contiguous().view(batch_size, -1, channel)
		batch_pred, batch_prob = translator.translate_batch(enc_img, length_imgs)
        
		label_seq = []
		for seq in labels.data.numpy():
			expression = ""
			for char_idx in seq:
				if char_idx == Constants.EOS:
					break
				else:
					expression += dataset.idx2word.get(char_idx, '')
			label_seq.append(expression)

		pre_seq = []
		for best_pred in batch_pred:
			for seq in best_pred:
				expression = ""
示例#12
0
def main():
    '''Translate ``-src``, write the predictions and print a BLEU score.

    Each hypothesis is cut at the first ``</s>``, its "▁" symbols are turned
    into spaces, and it is written on its own line to ``-output``. The token
    list of the best hypothesis for each source sentence is scored against
    the references in ``-target`` with ``evaluator.BLEUEvaluator``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-target',
        required=True,
        help='Reference target sequences for BLEU (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Preprocessed data file (.pt) with the settings and vocabularies')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    parser.add_argument('-prune', action='store_true')
    parser.add_argument('-prune_alpha', type=float, default=0.1)
    parser.add_argument('-load_mask', type=str, default=None)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']

    refs = read_instances_from_file(opt.target,
                                    preprocess_settings.max_word_seq_len,
                                    preprocess_settings.keep_case)

    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts,
        ),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)
    tgt_idx2word = test_loader.dataset.tgt_idx2word

    preds = []       # best hypothesis per source sentence, as tokens (for BLEU)
    preds_text = []  # every hypothesis as cleaned-up text (for -output)

    with torch.no_grad():
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, _ = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for rank, idx_seq in enumerate(idx_seqs):
                    tokens = [tgt_idx2word[idx] for idx in idx_seq]
                    # Drop everything from the first "</s>" on, then replace
                    # "▁" (presumably the SentencePiece word-boundary marker)
                    # with spaces.
                    sent = ' '.join(tokens).split("</s>")[0].strip()
                    preds_text.append(sent.replace("▁", " ").strip())
                    if rank == 0:
                        # Keep preds aligned 1:1 with refs when -n_best > 1.
                        preds.append(tokens)
    with open(opt.output, 'w', encoding='utf-8') as f:
        f.write('\n'.join(preds_text))

    from evaluator import BLEUEvaluator
    scorer = BLEUEvaluator()
    length = min(len(preds), len(refs))
    score = scorer.evaluate(refs[:length], preds[:length])
    print(score)
示例#13
0
def main():
    '''Decode the JD test set with a trained ContextTransformer and score it.

    Reads ``test_src.txt``, ``test_attr.txt`` and ``test_tgt.txt`` from
    ``-data_dir`` and beam-decodes each source with ``-n_best`` hypotheses.
    For every reference it prints the reference next to its hypotheses,
    together with per-sentence ``cdscore`` and ``bleu`` values. Finally it
    prints the averages over the test set.
    '''
    data_dir = "../data/jd/pure"
    parser = argparse.ArgumentParser(description='main_test.py')
    parser.add_argument('-model_path', default="log/model.ckpt", help='模型路径')
    parser.add_argument('-data_dir', default=data_dir, help='数据目录')
    parser.add_argument('-src', default=data_dir + "/test_src.txt", help='测试集源文件路径')
    parser.add_argument('-data', default=data_dir + "/reader.data", help='训练数据')
    parser.add_argument('-output_dir', default="output", help="输出路径")
    parser.add_argument('-beam_size', type=int, default=10, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('-n_best', type=int, default=3, help="""多句输出""")
    # Takes a value such as "cpu" or "cuda:1". The former store_true flag
    # replaced the device with the boolean True whenever it was given.
    parser.add_argument('-device', type=torch.device,
                        default=torch.device('cuda:0' if torch.cuda.is_available() else 'cpu'))
    args = parser.parse_args()
    os.makedirs(args.output_dir, exist_ok=True)

    print("加载词汇表", os.path.abspath(args.data))
    reader = torch.load(args.data)
    args.max_token_seq_len = reader['settings']["max_token_seq_len"]

    # NOTE(review): -src is parsed but the source file is always read from
    # -data_dir; confirm whether args.src was meant to be used here.
    test_src = read_file(path=args.data_dir + "/test_src.txt")
    test_ctx = read_file(path=args.data_dir + "/test_attr.txt")
    test_tgt = read_file(path=args.data_dir + "/test_tgt.txt")
    test_src, test_ctx, _ = digitalize(src=test_src, tgt=None, ctx=test_ctx, max_sent_len=20,
                                       word2idx=reader['dict']['src'], index2freq=reader["dict"]["frequency"], topk=0)

    test_loader = torch.utils.data.DataLoader(
        SeqDataset(
            src_word2idx=reader['dict']['src'],
            tgt_word2idx=reader['dict']['tgt'],
            ctx_word2idx=reader['dict']['ctx'],
            src_insts=test_src,
            ctx_insts=test_ctx),
        num_workers=4,
        batch_size=args.batch_size,
        collate_fn=paired_collate_fn)

    # Ids to penalise during decoding: ids 0-3 plus the listed words.
    # NOTE(review): the word ids are looked up in the *source* vocabulary,
    # but the mask is sized to the *target* vocabulary. That is only correct
    # if both vocabularies share indices — confirm.
    bad_words = ['您', '建', '猜', '查', '吗', '哪', '了', '问', '么', '&', '?']
    bad_idx = [0, 1, 2, 3] + [reader['dict']['src'][w] for w in bad_words]
    # A single row rather than one per batch element — presumably because the
    # last batch may be shorter than batch_size.
    bads = torch.ones((1, len(reader['dict']['tgt'])))
    for idx in bad_idx:
        # Log-probabilities are negative; both 0.01 and 100 were tried here.
        bads[0][idx] = 100
    args.bad_mask = bads

    checkpoint = torch.load(args.model_path)
    model_opt = checkpoint['settings']
    args.model = ContextTransformer(
        model_opt.ctx_vocab_size,
        model_opt.src_vocab_size,
        model_opt.tgt_vocab_size,
        model_opt.max_token_seq_len,
        tgt_emb_prj_weight_sharing=model_opt.proj_share_weight,
        emb_src_tgt_weight_sharing=model_opt.embs_share_weight,
        d_k=model_opt.d_k,
        d_v=model_opt.d_v,
        d_model=model_opt.d_model,
        d_word_vec=model_opt.d_word_vec,
        d_inner=model_opt.d_inner_hid,
        ct_layers=model_opt.en_layers,
        n_layers=model_opt.n_layers,
        n_head=model_opt.n_head,
        dropout=model_opt.dropout)
    args.model.load_state_dict(checkpoint['model'])
    args.model.word_prob_prj = torch.nn.LogSoftmax(dim=1)
    print('[Info] Trained model state loaded.')

    translator = Translator(max_token_seq_len=args.max_token_seq_len, beam_size=args.beam_size, n_best=args.n_best,
                            device=args.device, bad_mask=args.bad_mask, model=args.model)
    translator.model.eval()

    tgt_idx2word = test_loader.dataset.tgt_idx2word
    predicts = []  # one list of n_best detokenised hypotheses per source

    with torch.no_grad():
        for batch in tqdm(test_loader, mininterval=0.1, desc='  - (Test)', leave=False):
            all_hyp, _ = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:  # one entry per batch element
                answers = []
                for idx_seq in idx_seqs:  # n_best hypotheses
                    # Cut at the first EOS (excluded); keep everything if none.
                    end_pos = next(
                        (pos for pos, idx in enumerate(idx_seq) if idx == Constants.EOS),
                        len(idx_seq))
                    answers.append(' '.join(tgt_idx2word[idx] for idx in idx_seq[:end_pos]))
                predicts.append(answers)

    doc_bleu = 0.0
    doc_cdscore = 0.0
    for i, tgt_line in enumerate(test_tgt):
        print(tgt_line + "----->" + "_".join(predicts[i]))
        bleu_score = bleu([tgt_line], predicts[i])
        doc_bleu += bleu_score
        # Must not be named `cdscore`: assigning to that name anywhere in
        # main() makes it local and the call would raise UnboundLocalError.
        cd_score = cdscore([tgt_line], predicts[i])
        doc_cdscore += cd_score
        print(" cd_score:", cd_score, " bleu:", bleu_score)
    if test_tgt:  # avoid ZeroDivisionError on an empty test set
        doc_bleu /= len(test_tgt)
        doc_cdscore /= len(test_tgt)
    print(" doc bleu-->" + str(doc_bleu) + "   docCdscore-->" + str(doc_cdscore))
示例#14
0
class MULTI(nn.Module):
    """Wrap a ``MultiModel`` together with a beam-search ``Translator``.

    ``forward`` returns teacher-forced logits for training, or the decoded
    hypotheses when called with ``decode=True``.
    """

    def __init__(self, config):
        """Build the wrapped model and its translator from ``config``.

        Args:
            config: settings object. This constructor reads ``vocab_size``,
                ``max_history``, ``embedding_size``, ``decoder_hidden_size``,
                ``encoder_type``, ``decoder_type``, ``num_layers``,
                ``decoder_per_layer_attention``, ``beam_size`` and
                ``gen_response_len``.
        """
        super().__init__()

        self.config = config

        # config.vocab_size is passed for both of the first two arguments
        # (presumably the source and target vocabulary sizes).
        # TODO target weight sharing is disabled!
        self.model = MultiModel(config.vocab_size, config.vocab_size, config.max_history, config.embedding_size, config.decoder_hidden_size,
                                config.decoder_hidden_size * 4, encoder=config.encoder_type,
                                decoder=config.decoder_type, n_layers=config.num_layers, tgt_emb_prj_weight_sharing=False,
                                per_layer_decoder_attention=config.decoder_per_layer_attention)

        # The translator shares the model instance built above.
        self.translator = Translator(model=self.model, beam_size=config.beam_size,
                                     max_seq_len=config.gen_response_len)

    def forward(self, histories, segments, responses, decode=False):
        """Run teacher-forced scoring or decoding.

        Args:
            histories: token ids of the dialogue histories (encoder input).
            segments: segment ids for ``histories``; passed as ``src_segs``.
            responses: target token ids. An SOS token is prepended with
                ``add_sos`` before teacher forcing. Unused when ``decode``
                is True.
            decode: when True, decode with ``self.translator`` instead of
                returning logits.

        Returns:
            When ``decode`` is False, the output of ``self.model`` called with
            ``flat_logits=False``. Otherwise a list with the first hypothesis
            the translator returns for each batch item (presumably the
            best-scoring one).
        """
        history_pos = calc_pos(histories)

        if not decode:
            responses = add_sos(responses)
            response_pos = calc_pos(responses)
            return self.model(histories, history_pos, responses, response_pos,
                              flat_logits=False, src_segs=segments)

        # Decoding does not need the teacher-forced logits, so the training
        # forward pass above is skipped entirely on this path.
        # TODO go back to topk decoding
        # batch_hyp = self.translator.sample_topk_batch(histories, history_pos, src_segs=segments)
        batch_hyp, _ = self.translator.translate_batch(histories, history_pos, src_segs=segments)
        return [sent[0] for sent in batch_hyp]
def main():
    """Translate a test set with beam search and report token accuracy.

    Decodes every sequence in ``-src``. Each source / reference / prediction
    triple is written to ``-output``. The token-level accuracy of the first
    hypothesis against the ``-tgt`` reference is appended to ``-log``.
    """

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-src',
                        required=True,
                        help='Source sequence to decode '
                        '(one line per sequence)')
    parser.add_argument('-tgt',
                        required=True,
                        help='Reference target sequence used to score the '
                        'predictions (one line per sequence)')
    parser.add_argument('-vocab',
                        required=True,
                        help='Path to the preprocessed data file holding the '
                        'vocabularies and preprocessing settings')
    parser.add_argument('-log',
                        default='translate_log.txt',
                        help='Path of the file the test accuracy is appended to')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_tgt_word_insts = read_instances_from_file(
        opt.tgt, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_tgt_insts = convert_instance_to_idx_seq(
        test_tgt_word_insts, preprocess_data['dict']['tgt'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts,
        tgt_insts=test_tgt_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=paired_collate_fn)

    translator = Translator(opt)
    src_idx2word = test_loader.dataset.src_idx2word
    tgt_idx2word = test_loader.dataset.tgt_idx2word
    separator = "\n ----------------------------------------------------------------------------------------------------------------------------------------------  \n"

    def _fit_to_length(seq, length):
        """Return a NEW list: ``seq`` truncated or right-padded with PAD to ``length``.

        A copy is returned so the hypotheses written to the output file
        are not polluted with padding.
        """
        fitted = list(seq[:length])
        fitted.extend([Constants.PAD] * (length - len(fitted)))
        return fitted

    n_word_total = 0
    n_word_correct = 0

    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, _ = translator.translate_batch(batch[0], batch[1])

            src_seqs = batch[0]
            tgt_seqs = batch[2]
            # Score against the target without its first position
            # (presumably BOS).
            gold = tgt_seqs[:, 1:]
            max_len = gold.shape[1]

            # Align the first hypothesis of every item with the gold length.
            # Hypotheses may be shorter (padded) or longer (truncated) than
            # the reference.
            pred_seq = torch.LongTensor(
                [_fit_to_length(hyps[0], max_len) for hyps in all_hyp])
            # view(-1) instead of batch_size * max_len: the last batch may
            # hold fewer than opt.batch_size items.
            pred_seq = pred_seq.view(-1)

            n_correct = cal_performance(pred_seq, gold)

            non_pad_mask = gold.ne(Constants.PAD)
            n_word_total += non_pad_mask.sum().item()
            n_word_correct += n_correct

            # all_hyp holds one list of n_best hypotheses per batch item, so
            # pair it item-by-item with the source and target rows.
            for src_seq, tgt_seq, pred_seqs in zip(src_seqs, tgt_seqs, all_hyp):
                src_line = ' '.join(
                    [src_idx2word[idx] for idx in src_seq.data.cpu().numpy()])
                tgt_line = ' '.join(
                    [tgt_idx2word[idx] for idx in tgt_seq.data.cpu().numpy()])
                for hyp in pred_seqs:
                    pred_line = ' '.join([tgt_idx2word[idx] for idx in hyp])
                    f.write(separator)
                    f.write("\n [src]  " + src_line + '\n')
                    f.write("\n [tgt]  " + tgt_line + '\n')
                    f.write("\n [pred] " + pred_line + '\n')

        # An empty reference set has no tokens to score; report 0.0
        # rather than raising ZeroDivisionError.
        accuracy = n_word_correct / n_word_total if n_word_total else 0.0
        accr_log = "accuracy: {} |".format(accuracy)

        with open(opt.log, 'a') as log_tf:
            log_tf.write(accr_log + '\n')

    print('[Info] Finished.')
示例#16
0
def main():
    '''Main Function

    Decode the 'test' split provided by ``Loaders`` with beam search. For
    every hypothesis, write the prediction, its input frame and the
    ground-truth story to ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=8, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    loaders = Loaders()
    loaders.get_loaders(opt)
    test_loader = loaders.loader['test']
    story_vocab = loaders.story_vocab
    frame_vocab = loaders.frame_vocab

    opt.src_vocab_size = len(frame_vocab)
    opt.tgt_vocab_size = len(story_vocab)

    translator = Translator(opt)

    # buffering=1 makes the text file line-buffered, so finished entries
    # reach the disk even if the run is interrupted.
    with open(opt.output, 'w', buffering=1) as f:
        for frame, frame_pos, frame_sen_pos, gt_seqs, _ in tqdm(
                test_loader, mininterval=2, desc='  - (Test)', leave=False):
            all_hyp, _ = translator.translate_batch(
                frame, frame_pos, frame_sen_pos)
            # Presumably all_hyp[i] holds the n_best hypotheses of batch item
            # i, so it is paired item-by-item with that item's frame and
            # ground truth.
            for idx_seqs, idx_frame, gt_seq in zip(all_hyp, frame, gt_seqs):
                frame_line = ' '.join([
                    frame_vocab.idx2word[idx.item()]
                    for idx in idx_frame if idx != Constants.PAD
                ])
                gt_line = ' '.join([
                    story_vocab.idx2word[idx.item()]
                    for idx in gt_seq if idx != Constants.PAD
                ])
                for idx_seq in idx_seqs:
                    pred_line = ' '.join([
                        story_vocab.idx2word[idx] for idx in idx_seq
                        if idx != Constants.BOS
                    ])
                    f.write('Prediction:' + '\n')
                    f.write(pred_line + '\n')
                    f.write('Frame:' + '\n')
                    f.write(frame_line + '\n')
                    f.write('Ground Truth:' + '\n')
                    f.write(gt_line + '\n')
                    f.write(
                        "===============================================\n")
    print('[Info] Finished.')
示例#17
0
def main():
    '''Main Function

    Translate one hard-coded source line. Then save heat maps of four
    attention heads for the first two encoder and decoder layers as PNG
    files whose paths start with ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True,
                        help='Path to model .pt file')
    parser.add_argument('-vocab', required=True,
                        help='Path to vocabulary file')
    # Default '' saves the figures in the working directory. Without a
    # default, opt.output is None and the path concatenation below raises.
    parser.add_argument('-output', default='',
                        help="""Path to output the predictions""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    src_line = "Binary files a / build / linux / jre . tgz and b / build / linux / jre . tgz differ <nl>"

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances(
        src_line,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=1,
        collate_fn=collate_fn)

    translator = Translator(opt)

    # Only the last decoded hypothesis is kept for labelling the plots.
    # Start from '' so an empty loader does not leave the name unbound.
    pred_line = ''
    for batch in tqdm(test_loader, mininterval=1, desc='  - (Test)', leave=False):
        all_hyp, _ = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                # The final token is dropped (presumably EOS).
                pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq[:-1]])
            print(pred_line)

    sent = src_line.split()
    tgt_sent = pred_line.split()
    encoder_layers = translator.model.encoder.layer_stack
    decoder_layers = translator.model.decoder.layer_stack

    for layer in range(0, 2):
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        print("Encoder Layer", layer + 1)
        enc_attn = encoder_layers[layer].slf_attn.attn
        for h in range(4):
            print(enc_attn.data.cpu().size())
            draw(enc_attn[h, :, :].data.cpu(),
                 sent, sent if h == 0 else [], ax=axs[h])
        plt.savefig(opt.output + "Encoder Layer %d.png" % layer)
        # Close each figure once saved so figures do not pile up in memory.
        plt.close(fig)

    for layer in range(0, 2):
        dec_attn = decoder_layers[layer].slf_attn.attn

        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        print("Decoder Self Layer", layer + 1)
        for h in range(4):
            print(dec_attn.data.cpu().size())
            draw(dec_attn[:, :, h].data[:len(tgt_sent), :len(tgt_sent)].cpu(),
                 tgt_sent, tgt_sent if h == 0 else [], ax=axs[h])
        plt.savefig(opt.output + "Decoder Self Layer %d.png" % layer)
        plt.close(fig)

        print("Decoder Src Layer", layer + 1)
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        # NOTE(review): these "Src" plots read the decoder *self*-attention
        # (slf_attn) again, not an encoder-decoder attention. Confirm the
        # decoder layer's source-attention attribute and switch to it.
        for h in range(4):
            draw(dec_attn[:, :, h].data[:len(sent), :len(tgt_sent)].cpu(),
                 tgt_sent, sent if h == 0 else [], ax=axs[h])
        plt.savefig(opt.output + "Decoder Src Layer %d.png" % layer)
        plt.close(fig)

    print('[Info] Finished.')
示例#18
0
def main():
    '''Main Function

    Translate the sentences in ``-src`` with a trained model (this model
    translates English to German). Write one decoded sentence per line to
    ``-output``.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    # The defaults are the files this script expects. Values given on the
    # command line now take precedence instead of being overwritten after
    # parsing.
    parser.add_argument('-model', required=False, default='trained.chkpt',
                        help='Path to model .pt file')
    parser.add_argument('-src', required=False, default='1',
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=False, default='multi30k.atok.low.pt',
                        help='Path to the preprocessed data file '
                        '(e.g. data/multi30k.atok.low.pt)')
    parser.add_argument('-output', default='2',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30,
                        help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    # CUDA is always disabled for this script, regardless of -no_cuda.
    opt.cuda = False

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)

    # Debug dumps of the source and target vocabularies.
    with open('55', 'w', encoding='utf-8') as f:
        f.write(str(preprocess_data['dict']['src']))
    with open('66', 'w', encoding='utf-8') as f:
        f.write(str(preprocess_data['dict']['tgt']))

    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    # UTF-8 explicitly: the German output contains non-ASCII characters that
    # the platform default encoding may not support.
    with open(opt.output, 'w', encoding='utf-8') as f:
        for batch in test_loader:
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    print(idx_seq)
                    # Convert the token ids back to text.
                    pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
示例#19
0
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(opt)

    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq[:-1]
                    ])
                    f.write(pred_line + '\n')

    print('[Info] Finished.')
示例#20
0
def main():
    '''Main Function

    Generate a story for every term set in the 'vist_term' test loader.
    Readable and plain prediction / ground-truth files are written under
    ``./test/``, and the predictions are stored as 'predicted_story' in a
    copy of the generated-terms JSON.
    '''
    import os

    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=3, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-device', action='store')
    parser.add_argument('-positional',
                        type=str,
                        choices=['Default', 'LRPE', 'LDPE'],
                        default='Default')

    opt = parser.parse_args()
    # CUDA is forced off for this script, regardless of -no_cuda.
    opt.cuda = False

    # Prepare DataLoader
    loaders = Loaders()
    loaders.get_test_loaders(opt)
    test_loader = loaders.loader['vist_term']
    story_vocab = loaders.story_vocab
    frame_vocab = loaders.frame_vocab

    opt.src_vocab_size = len(frame_vocab)
    opt.tgt_vocab_size = len(story_vocab)

    with open('../data/generated_terms/VIST_test_self_output_diverse.json',
              encoding='utf-8') as f_in:
        output = json.load(f_in)
    count = 0
    # Token ids dropped from the displayed/saved stories (presumably the
    # BOS markers).
    BOS_set = {2, 3, 4, 5, 6, 7}
    translator = Translator(opt)

    def to_ascii(text):
        """Return ``text`` with every non-ASCII character removed."""
        return text.encode('ascii', 'ignore').decode('ascii')

    os.makedirs('./test', exist_ok=True)
    with open("./test/" + opt.output, 'w', buffering=1) as f_pred, \
            open("./test/" + 'gt.txt', 'w', buffering=1) as f_gt, \
            open("./test/" + 'show.txt', 'w', buffering=1) as f:
        for frame, frame_pos, frame_sen_pos, gt_seqs, _, _ in tqdm(
                test_loader, mininterval=2, desc='  - (Test)', leave=False):
            all_hyp, _ = translator.translate_batch(
                frame, frame_pos, frame_sen_pos)
            for idx_seqs, idx_frame, gt_seq in zip(all_hyp, frame, gt_seqs):
                frame_line = ' '.join([
                    frame_vocab.idx2word[idx.item()]
                    for idx in idx_frame if idx != Constants.PAD
                ])
                gt_line = ' '.join([
                    story_vocab.idx2word[idx.item()]
                    for idx in gt_seq if idx != Constants.PAD
                ])
                for idx_seq in idx_seqs:
                    story = ' '.join([
                        story_vocab.idx2word[idx] for idx in idx_seq
                        if idx not in BOS_set
                    ])

                    f.write('Prediction:' + '\n')
                    f.write(to_ascii(story + '\n'))
                    f.write('Frame:' + '\n')
                    f.write(to_ascii(frame_line + '\n'))
                    f.write('Ground Truth:' + '\n')
                    f.write(to_ascii(gt_line + '\n'))
                    f.write(
                        "===============================================\n")

                    pred_line = ' '.join([
                        story_vocab.idx2word[idx] for idx in idx_seq
                        if idx != Constants.PAD
                    ])
                    f_pred.write(to_ascii(pred_line + '\n'))
                    f_gt.write(to_ascii(gt_line + '\n'))

                    # Each decoded story fills the next five JSON entries
                    # (presumably five entries per test item).
                    for _ in range(5):
                        output[count]['predicted_story'] = story
                        count += 1

    print('[Info] Finished.')
    # '_percent_' + positional + '.json' (previously 'percent.json' +
    # positional + '.json', which produced a double '.json' in the name).
    filename = ('../data/generated_story/VIST_test_self_output_diverse_noun2_norm_penalty_VISTData_percent_'
                + str(opt.positional) + '.json')

    with open(filename, 'w') as f_out:
        json.dump(output, f_out, indent=4)
示例#21
0
    def generate(self,
                 features,
                 json_file,
                 bpm,
                 unique_states,
                 temperature,
                 use_beam_search=False,
                 generate_full_song=False):
        """Generate a block sequence for one song.

        Args:
            features: 2-D array of audio features. One zero column is added
                at each end of axis 1, and then columns are selected by
                ``indices``.
            json_file: passed to ``get_block_sequence_with_deltas``.
            bpm: tempo in beats per minute.
            unique_states: passed as ``states`` to
                ``get_block_sequence_with_deltas``.
            temperature: sampling temperature; used only when
                ``use_beam_search`` is False.
            use_beam_search: decode with ``translate_batch`` (first
                hypothesis) instead of ``sample_translation``.
            generate_full_song: if False, the sequence is capped at
                ``opt.max_token_seq_len`` steps.

        Returns:
            ``(state_times, generated_sequence)``. Only a single-item batch
            is generated.

        Raises:
            NotImplementedError: if ``opt.tgt_vector_input`` is set.
        """
        opt = self.opt

        # Decoding with deltas attached to the target sequence is not
        # implemented. Fail before doing any preprocessing work.
        if opt.tgt_vector_input:
            raise NotImplementedError(
                "Need to implement beam search for Transformer target vector inputs (when we attach deltas to target sequence)"
            )

        # Pad the features with one zero column on each side of axis 1.
        y = np.concatenate((np.zeros((features.shape[0], 1)), features), 1)
        y = np.concatenate((y, np.zeros((y.shape[0], 1))), 1)
        if opt.using_bpm_time_division:
            beat_duration = 60 / bpm  # beat duration in seconds
            # sample duration in seconds
            sample_duration = beat_duration * 1 / opt.beat_subdivision
        else:
            sample_duration = opt.step_size
        sequence_length = y.shape[1] * sample_duration

        ## BLOCKS TENSORS ##
        _, states, state_times, delta_forward, delta_backward, indices = get_block_sequence_with_deltas(
            json_file,
            sequence_length,
            bpm,
            sample_duration,
            top_k=2000,
            states=unique_states,
            one_hot=True,
            return_state_times=True)
        if generate_full_song:
            truncated_sequence_length = len(states)
        else:
            truncated_sequence_length = min(len(states), opt.max_token_seq_len)
        indices = indices[:truncated_sequence_length]
        delta_forward = delta_forward[:, :truncated_sequence_length]
        delta_backward = delta_backward[:, :truncated_sequence_length]

        input_forward_deltas = torch.tensor(delta_forward).unsqueeze(0).long()
        input_backward_deltas = torch.tensor(delta_backward).unsqueeze(0).long()

        y = y[:, indices]
        song_sequence = torch.tensor([y])
        # Subtract the global mean and divide by the largest absolute value
        # of the uncentred input.
        # NOTE(review): all-zero features make this 0/0 (NaN).
        song_sequence = (song_sequence - song_sequence.mean()
                         ) / torch.abs(song_sequence).max().float()
        # Append the forward/backward deltas along dim 1.
        song_sequence = torch.cat([
            song_sequence,
            input_forward_deltas.double(),
            input_backward_deltas.double()
        ], 1)

        src_pos = torch.tensor(np.arange(len(indices))).unsqueeze(0)
        src_mask = torch.tensor(constants.NUM_SPECIAL_STATES *
                                np.ones(len(indices))).unsqueeze(0)

        ## actually generate level ##
        translator = Translator(opt, self)
        translator.model.eval()
        src_seq = song_sequence.permute(0, 2, 1).float()
        with torch.no_grad():
            if use_beam_search:
                all_hyp, _ = translator.translate_batch(
                    src_seq, src_pos, src_mask, truncated_sequence_length)
                generated_sequence = all_hyp[0][0]
            else:
                generated_sequence = translator.sample_translation(
                    src_seq, src_pos, src_mask, truncated_sequence_length,
                    temperature)
        # Only single-batch generation is supported for now.
        return state_times, generated_sequence
示例#22
0
def main():
    '''Main Function

    For each batch of window term sets from the 'add_window_termset' test
    loader, decode one story per window and stitch the windows together.
    The stitched story is stored under 'add_one_path_story' in a copy of the
    generated-terms JSON.
    '''

    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=3, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-device', action='store')
    parser.add_argument('-positional',
                        type=str,
                        choices=['Default', 'LRPE', 'LDPE'],
                        default='Default')
    parser.add_argument('-insert',
                        required=True,
                        type=int,
                        help="Term's insert number")
    parser.add_argument('-relation',
                        required=True,
                        type=str,
                        help="relation file")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    print('opt.insert', opt.insert)
    print('opt.relation', opt.relation)
    loaders = Loaders()
    loaders.get_test_add_path_loaders(opt)
    test_loader = loaders.loader['add_window_termset']
    story_vocab = loaders.story_vocab

    opt.src_vocab_size = len(loaders.frame_vocab)
    opt.tgt_vocab_size = len(story_vocab)

    with open('../data/generated_terms/VIST_test_self_output_diverse.json',
              encoding='utf-8') as f_in:
        output = json.load(f_in)
    count = 0
    # Regex that splits a story at any BOS token. The tokens are escaped so
    # that regex metacharacters inside them are matched literally.
    BOSs_re = "|".join(
        re.escape(story_vocab.idx2word[idx]) for idx in Constants.BOSs)
    translator = Translator(opt)

    # No files are opened here: the ./test/ outputs written by the
    # prediction script must not be truncated by this one.
    for frame, frame_pos, frame_sen_pos, gt_seqs, _, _ in tqdm(
            test_loader, mininterval=2, desc='  - (Test)', leave=False):
        all_hyp, _ = translator.translate_batch(
            frame, frame_pos, frame_sen_pos)
        print('frame', frame)
        print('frame_sen_pos', frame_sen_pos)
        print('all_hyp', all_hyp)
        window_stories = []
        for idx_seqs, _, _ in zip(all_hyp, frame, gt_seqs):
            for idx_seq in idx_seqs:
                pred_line = ' '.join([
                    story_vocab.idx2word[idx] for idx in idx_seq
                    if idx != Constants.EOS
                ])
                window_stories.append(re.split(BOSs_re, pred_line))

        # Keep every segment of the first window, then only the last
        # segment of each following window.
        if window_stories:
            stitched = list(window_stories[0]) + [
                segments[-1] for segments in window_stories[1:]
            ]
        else:
            stitched = []
        story = re.sub(' +', ' ', " ".join(stitched))
        # Each batch fills five consecutive JSON entries.
        for i in range(count * 5, count * 5 + 5):
            output[i]['add_one_path_story'] = story
        count += 1
    print('[Info] Finished.')

    filename = '../data/generated_story/VIST_test_self_output_diverse_add_highest_one_path_noun' + str(
        opt.insert + 1) + str(
            opt.relation) + '_norm_penalty_coor_VISTdataset_percent_' + str(
                opt.positional) + '.json'

    with open(filename, 'w') as f_out:
        json.dump(output, f_out, indent=4)
示例#23
0
def main():
    """Main Function

    Parse the command-line options, load the preprocessed vocabulary file,
    translate every line of ``-src`` with the trained model and write the
    decoded hypotheses (every hypothesis returned per source line) to
    ``-output``, one per line.
    """

    parser = argparse.ArgumentParser(description="translate.py")

    parser.add_argument("-model", required=True, help="Path to model .pt file")
    parser.add_argument(
        "-src",
        required=True,
        help="Source sequence to decode (one line per sequence)")
    parser.add_argument(
        "-vocab",
        required=True,
        help="Path to the preprocessed data file (torch.load-able) that "
        "contains the settings and the source/target vocabularies")
    parser.add_argument("-output",
                        default="pred.txt",
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument("-beam_size", type=int, default=5, help="Beam size")
    parser.add_argument("-batch_size", type=int, default=30, help="Batch size")
    parser.add_argument("-n_best",
                        type=int,
                        default=1,
                        help="""Number of best decoded sentences to output
                        for each source sequence""")
    parser.add_argument("-no_cuda", action="store_true")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader.
    # NOTE(review): torch.load unpickles the file; only use trusted vocab files.
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data["settings"]
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data["dict"]["src"])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data["dict"]["src"],
        tgt_word2idx=preprocess_data["dict"]["tgt"],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(opt)

    # Explicit encoding so non-ASCII target tokens are written the same way
    # regardless of the platform's default locale encoding.
    with open(opt.output, "w", encoding="utf-8") as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc="  - (Test)",
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            # all_hyp holds, per source sequence, a list of index hypotheses.
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = " ".join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq
                    ])
                    f.write(pred_line + "\n")
    print("[Info] Finished.")
示例#24
0
''' Translate input text with trained model. '''
示例#25
0
def _ground_truth_vector(seq, numuser, max_users=None):
    """Build a multi-hot ground-truth vector from one target sequence.

    The first and last positions of ``seq`` are skipped (``seq[1:-1]``) and
    reading stops at the first ``Constants.EOS`` or ``Constants.PAD``. When
    ``max_users`` is given, at most that many users are read.

    Returns:
        (ground_truth, num_ground_truth): a float vector of length
        ``numuser`` with 1.0 at every observed user id, and the number of
        positions read (a repeated id is counted each time it occurs).
    """
    ground_truth = np.zeros([numuser])
    num_ground_truth = 0
    for user in seq[1:-1]:
        if user == Constants.EOS or user == Constants.PAD:
            break
        ground_truth[user] = 1.0
        num_ground_truth += 1
        if max_users is not None and num_ground_truth >= max_users:
            break
    return ground_truth, num_ground_truth


def _prediction_counts(samples, num_steps, numuser, beam_size):
    """Average predicted user ids over the beams of one sequence.

    ``samples`` is indexed as ``samples[beam, step]`` (presumably the
    per-sequence slice of ``translate_batch`` output -- TODO confirm). For
    each of the ``beam_size`` beams, steps ``1..num_steps`` are read until
    ``Constants.EOS``; each predicted id adds ``1 / beam_size`` to its slot.
    """
    pred_cnt = np.zeros([numuser])
    for beid in range(beam_size):
        for pred_uid in samples[beid, 1:num_steps + 1]:
            if pred_uid == Constants.EOS:
                break
            pred_cnt[pred_uid] += 1.0 / beam_size
    return pred_cnt


def _safe_div(numerator, denominator):
    """Return ``numerator / denominator``, or 0.0 when the denominator is 0."""
    return numerator / denominator if denominator else 0.0


def _micro_scores(right, pred, total):
    """Return (F1, precision, recall) computed from micro-averaged counts."""
    precision = _safe_div(right, pred)
    recall = _safe_div(right, total)
    f1 = _safe_div(2 * precision * recall, precision + recall)
    return f1, precision, recall


def main():
    '''Main Function

    Decode the validation split with beam search and print macro- and
    micro-averaged F1 / precision / recall of the predicted users, once over
    the full ground truth and once over at most its first 5 users.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=100, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=1, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader (validation split).
    test_data = DataLoader(use_valid=True,
                           batch_size=opt.batch_size,
                           cuda=opt.cuda)

    translator = Translator(opt)
    translator.model.eval()

    numuser = test_data.user_size
    short_len = 5  # the "short" metrics only look at the first 5 users

    numseq = 0  # number of test sequences

    # Macro accumulators: sums of per-sequence F1 / precision / recall.
    avgF1 = avgPre = avgRec = 0.
    avgF1_short = avgPre_short = avgRec_short = 0.

    # Micro accumulators: hits, predicted mass, ground-truth size.
    right = pred = total = 0.
    right_short = pred_short = total_short = 0.

    # NOTE(review): the output file is opened (and truncated) but nothing is
    # ever written to it; the open is kept to preserve that side effect.
    with open(opt.output, 'w'):
        for batch in tqdm(test_data,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_samples = translator.translate_batch(batch).data

            for bid in range(batch.size(0)):
                numseq += 1
                target = batch.data[bid]
                samples = all_samples[bid]

                # Full ground truth.
                ground_truth, num_ground_truth = _ground_truth_vector(
                    target, numuser)
                pred_cnt = _prediction_counts(samples, num_ground_truth,
                                              numuser, opt.beam_size)
                F1, pre, rec = getF1(ground_truth, pred_cnt)
                avgF1 += F1
                avgPre += pre
                avgRec += rec
                right += np.dot(ground_truth, pred_cnt)
                pred += np.sum(pred_cnt)
                total += np.sum(ground_truth)

                # First no more than `short_len` ground-truth users.
                ground_truth, num_ground_truth = _ground_truth_vector(
                    target, numuser, max_users=short_len)
                pred_cnt = _prediction_counts(samples, num_ground_truth,
                                              numuser, opt.beam_size)
                F1, pre, rec = getF1(ground_truth, pred_cnt)
                avgF1_short += F1
                avgPre_short += pre
                avgRec_short += rec
                right_short += np.dot(ground_truth, pred_cnt)
                pred_short += np.sum(pred_cnt)
                total_short += np.sum(ground_truth)

    print('[Info] Finished.')
    print('Macro')
    print(_safe_div(avgF1, numseq))
    print(_safe_div(avgPre, numseq))
    print(_safe_div(avgRec, numseq))
    print('Results for the first no more than 5 predictions')
    print(_safe_div(avgF1_short, numseq))
    print(_safe_div(avgPre_short, numseq))
    print(_safe_div(avgRec_short, numseq))

    print('Micro')
    for value in _micro_scores(right, pred, total):
        print(value)

    print('Results for the first no more than 5 predictions')
    for value in _micro_scores(right_short, pred_short, total_short):
        print(value)
示例#26
0
def main():
    '''Translate ``-src`` (optionally conditioned on ``-ctx``) with the
    trained model and write every decoded hypothesis to ``-output``, one
    per line, with a trailing EOS token stripped.'''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model',
                        default='trained.chkpt',
                        help='Path to model .pt file')
    parser.add_argument(
        '-src',
        default='data/multi30k/test.en.atok',
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-ctx',
        required=False,
        default="",
        help='Context sequence to decode (one line per sequence)')
    parser.add_argument('-vocab',
                        default='data/multi30k.atok.low.pt',
                        help='Data that contains the source vocabulary')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    # store_true: passing -no_cuda disables the GPU (store_false inverted it).
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-max_token_seq_len', type=int, default=100)

    opt = parser.parse_args()
    # Use the GPU unless -no_cuda is given or no CUDA device is present, so
    # CPU-only machines keep working with the default options.
    opt.cuda = not opt.no_cuda and torch.cuda.is_available()

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']

    test_src_word_insts = read_instances_from_file(
        opt.src, opt.max_token_seq_len, preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_ctx_insts = None
    if opt.ctx:
        from preprocess_ctx import read_instances_from_file as read_instances_from_file_ctx
        test_ctx_word_insts = read_instances_from_file_ctx(
            opt.ctx,
            opt.max_token_seq_len,
            preprocess_settings.keep_case,
            is_ctx=True)
        test_ctx_insts = convert_instance_to_idx_seq(
            test_ctx_word_insts, preprocess_data['dict']['src'])

    test_data = DataLoader(preprocess_data['dict']['src'],
                           preprocess_data['dict']['tgt'],
                           src_insts=test_src_insts,
                           ctx_insts=test_ctx_insts,
                           cuda=opt.cuda,
                           shuffle=False,
                           batch_size=opt.batch_size,
                           is_train=False)

    translator = Translator(opt)
    translator.model.eval()

    eos_idx = 3  # index this script treats as the EOS token

    with open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_data,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    # Drop a trailing EOS; guard against empty hypotheses.
                    if len(idx_seq) > 0 and idx_seq[-1] == eos_idx:
                        idx_seq = idx_seq[:-1]
                    pred_line = ' '.join(
                        [test_data.tgt_idx2word[int(idx)] for idx in idx_seq])
                    f.write(pred_line + '\n')

    print('[Info] Finished.')