Example #1
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    # num_workers is the number of worker processes used to load the data
    # collate_fn wraps how the samples coming out of the dataset are combined into a batch;
    # the default is usually fine unless your custom dataset returns something unusual
    # (details of collate_fn are skipped here)
    train_loader = torch.utils.data.DataLoader(
        # The first two TranslationDataset arguments are the word-to-index dictionaries,
        # the last two are the data. What matters is that this class implements the
        # Dataset interface, i.e. the __len__ and __getitem__ methods:
        # __len__ returns the dataset length (the length of src_insts) and
        # __getitem__(i) returns the i-th sample, i.e. (src_insts[i], tgt_insts[i]).
        # paired_collate_fn is used here because __getitem__ returns a tuple,
        # not a single item.
        # shuffle : set to True to have the data reshuffled at every epoch,
        # so the batch order differs from one epoch to the next
        TranslationDataset(src_word2idx=data['dict']['src'],
                           tgt_word2idx=data['dict']['tgt'],
                           src_insts=data['train']['src'],
                           tgt_insts=data['train']['tgt']),
        # two worker processes are used to load the data
        num_workers=2,
        # batch_size is 64 here
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)
    return train_loader, valid_loader
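The comments above lean on two pieces this snippet does not show: the Dataset protocol (__len__/__getitem__) and paired_collate_fn. As a rough sketch only (not this project's actual implementation; the padding index and the 1-based position scheme are assumptions), a paired collate function for such a dataset pads both sides of every (src, tgt) pair to the batch maximum and builds position ids:

import torch

PAD = 0  # assumed padding index; the real project keeps this in a constants module

def collate_fn_sketch(insts):
    # Pad a list of index sequences to the batch max length and build 1-based position ids.
    max_len = max(len(inst) for inst in insts)
    seqs = torch.LongTensor([list(inst) + [PAD] * (max_len - len(inst)) for inst in insts])
    pos = torch.LongTensor([[i + 1 if tok != PAD else 0 for i, tok in enumerate(seq)]
                            for seq in seqs.tolist()])
    return seqs, pos

def paired_collate_fn_sketch(insts):
    # __getitem__ yields (src_insts[i], tgt_insts[i]) tuples; pad each side independently.
    src_insts, tgt_insts = list(zip(*insts))
    return (*collate_fn_sketch(src_insts), *collate_fn_sketch(tgt_insts))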
Example #2
def prepare_dataloaders(data, opt, distributed):
    # ========= Preparing DataLoader =========#
    train_dataset = TranslationDataset(src_word2idx=data['dict']['src'],
                                       tgt_word2idx=data['dict']['tgt'],
                                       src_insts=data['train']['src'],
                                       tgt_insts=data['train']['tgt'])
    if distributed:
        train_sampler = torch.utils.data.distributed.DistributedSampler(
            train_dataset)
    else:
        train_sampler = None

    train_loader = torch.utils.data.DataLoader(train_dataset,
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn,
                                               shuffle=train_sampler is None,
                                               sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)
    return train_loader, valid_loader
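Note that DistributedSampler shuffles deterministically per epoch; to get a different shard order every epoch, the training loop is expected to call set_epoch on it before iterating. A minimal sketch (the epoch count here is a made-up placeholder):

n_epochs = 10  # hypothetical; in this script it would come from opt
for epoch in range(n_epochs):
    if train_sampler is not None:
        train_sampler.set_epoch(epoch)  # re-seeds the per-epoch shuffle consistently across workers
    for batch in train_loader:
        pass  # forward/backward as usual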
Example #3
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=int(opt.num_workers),
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        pin_memory=True,
        shuffle=False)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        num_workers=int(opt.num_workers),
        batch_size=opt.batch_size,
        pin_memory=True,
        collate_fn=paired_collate_fn)
    return train_loader, valid_loader
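This variant adds pin_memory=True, which collates batches into page-locked host memory so that host-to-GPU copies can overlap with compute. A typical (hypothetical) consumer would then move the tensors with non_blocking=True; the 4-tuple batch layout below is an assumption about what paired_collate_fn returns:

device = torch.device('cuda')
for src_seq, src_pos, tgt_seq, tgt_pos in train_loader:
    src_seq = src_seq.to(device, non_blocking=True)
    tgt_seq = tgt_seq.to(device, non_blocking=True)
    # ... forward/backward as usual ...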
Example #4
def prepare_dataloaders(data, opt):
    print(data["settings"])

    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data["dict"]["src"],
            tgt_word2idx=data["dict"]["tgt"],
            src_insts=data["train"]["src"],
            tgt_insts=data["train"]["tgt"],
        ),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True,
    )

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data["dict"]["src"],
            tgt_word2idx=data["dict"]["tgt"],
            src_insts=data["valid"]["src"],
            tgt_insts=data["valid"]["tgt"],
        ),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
    )

    src_vocab_size = train_loader.dataset.src_vocab_size
    trg_vocab_size = train_loader.dataset.tgt_vocab_size

    src_idx2word = {idx: word for word, idx in data['dict']['src'].items()}
    trg_idx2word = {idx: word for word, idx in data['dict']['tgt'].items()}

    return train_loader, valid_loader, src_vocab_size, trg_vocab_size, src_idx2word, trg_idx2word
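The extra return values (vocabulary sizes and the reversed dictionaries) are convenient for sizing the model and for decoding its output. A hypothetical decoding snippet using them:

# pred_idx_seq is assumed to be a list of target-vocabulary indices produced by the model
pred_words = [trg_idx2word[idx] for idx in pred_idx_seq]
print(' '.join(pred_words))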
Example #5
def prepare_dataloaders(data, opt):

    validation_split = 0.1
    shuffle_dataset = True
    random_seed = 42

    initDataset = TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt'],
        sp_insts=data['train']['sp'])

    # Creating data indices for training and validation splits:
    dataset_size = len(initDataset)
    indices = list(range(dataset_size))
    split = int(np.floor(validation_split * dataset_size))
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(indices)
    train_indices, val_indices = indices[split:], indices[:split]

    # Creating PT data samplers and loaders:
    train_sampler = SubsetRandomSampler(train_indices)
    valid_sampler = SubsetRandomSampler(val_indices)


    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        initDataset,
        num_workers=4,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        sampler=train_sampler)

    valid_loader = torch.utils.data.DataLoader(
        initDataset,
        num_workers=4,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        sampler=valid_sampler)

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt'],
            sp_insts=data['valid']['sp']
        ),
        num_workers=4,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader, test_loader
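Neither loader sets shuffle here, and that is deliberate: DataLoader treats the sampler argument as mutually exclusive with shuffle=True, so the randomization has to come from the SubsetRandomSampler instances themselves.

# Not allowed:
# torch.utils.data.DataLoader(initDataset, sampler=train_sampler, shuffle=True)
# -> ValueError: sampler option is mutually exclusive with shuffle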
Example #6
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn,
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)
    # for j in train_loader:
    #     print(">>>>>>>>>>>>>>>>", len(j), j[0].shape, j[1].shape)
    #     break
    #
    # Output: >>>>>>>>>>>>>>>> 4 torch.Size([64, 51]) torch.Size([64, 51])
    #
    # A full batch prints as:
    # (tensor([[ 2, 24,  1,  ...,  0,  0,  0], ...]),
    #  tensor([[ 1,  2,  3,  ...,  0,  0,  0], ...]),
    #  tensor([[ 2,  1, 33,  ...,  0,  0,  0], ...]),
    #  tensor([[ 1,  2,  3,  ...,  0,  0,  0], ...]))

    return train_loader, valid_loader
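Judging from the debug dump above, each batch produced by paired_collate_fn is a 4-tuple of padded LongTensors, presumably (src_seq, src_pos, tgt_seq, tgt_pos), each of shape (batch_size, max_seq_len). A quick way to confirm the layout:

src_seq, src_pos, tgt_seq, tgt_pos = next(iter(train_loader))
print(src_seq.shape, src_pos.shape, tgt_seq.shape, tgt_pos.shape)
# expected: four tensors of shape torch.Size([64, <max_len>]), matching the dump above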
Example #7
def test(opt):
    """ Functions to test the model and implement machine translation"""
    # Prepare DataLoader
    preprocess_data = t.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = t.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
                                          num_workers=2,
                                          batch_size=opt.batch_size,
                                          collate_fn=collate_fn)

    translator = TransformerTranslator(opt)

    with open(opt.output, 'w', encoding='utf-8') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq
                    ])
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
Example #8
def main():
    parser = argparse.ArgumentParser()

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Path to the preprocessed vocabulary file')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                            be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                            decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')  # store_true: the flag is False by default and becomes True when the option is given

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)
    translator = Translator(opt)

    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for hyp_stream in all_hyp:
                for hyp in hyp_stream:
                    pred_sent = ' '.join(
                        [test_loader.dataset.tgt_idx2word[idx] for idx in hyp])
                    f.write(pred_sent + '\n')
    print('[Info] Finished')
Example #9
def prepare_dataloaders(data, mined_data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['test']['src'],
            tgt_insts=data['test']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    mined_loader = torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=mined_data['dict']['src'],
                tgt_word2idx=mined_data['dict']['tgt'],
                src_insts=mined_data['train']['src'],
                tgt_insts=mined_data['train']['tgt']),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn)

    return train_loader, valid_loader, test_loader, mined_loader
Example #10
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict'], tgt_word2idx=data['dict'],  # same for language modelling
            src_insts=data['train']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict'], tgt_word2idx=data['dict'],  # same word2idx for language modelling
            src_insts=data['train']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)
    return train_loader, valid_loader
Example #11
def prepare_dataloaders(data, opt):
    # Reorganize the data
    print("Reorganizing data...")
    src_pre = data['train']['src']
    tgt_pre = data['train']['tgt']
    all_data = list(zip(src_pre, tgt_pre))

    print("sample 数据...")
    sampler = BatchSampler(SequentialSampler(all_data),
                           batch_size=opt.batch_size,
                           drop_last=False)
    index = [s for s in sampler]
    random.shuffle(index)
    index = list(itertools.chain.from_iterable(index))

    print("重新赋值...")
    src = []
    tgt = []
    for i in index:
        src.append(src_pre[i])
        tgt.append(tgt_pre[i])
    data['train']['src'] = src
    data['train']['tgt'] = tgt
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt']),
                                               num_workers=0,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn,
                                               shuffle=False)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=0,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)
    return train_loader, valid_loader
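This variant effectively shuffles at batch granularity: BatchSampler over a SequentialSampler yields contiguous index chunks, the chunk order is shuffled, the flattened index list is used to rewrite data['train'], and the DataLoader then runs with shuffle=False. Presumably the goal is to randomize batch order while keeping the instances inside each batch in their original (e.g. length-sorted) order. A small standalone illustration of the index manipulation:

import itertools
import random
from torch.utils.data import BatchSampler, SequentialSampler

toy = list(range(10))  # stand-in for the training instances
batches = list(BatchSampler(SequentialSampler(toy), batch_size=3, drop_last=False))
print(batches)   # [[0, 1, 2], [3, 4, 5], [6, 7, 8], [9]]
random.shuffle(batches)
order = list(itertools.chain.from_iterable(batches))
print(order)     # e.g. [3, 4, 5, 0, 1, 2, 9, 6, 7, 8]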
Example #12
def prepare_dataloaders(data, opt):
    train_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)

    return train_loader, valid_loader
Example #13
def prepare_dataloaders(data, opt):  #pass
    # Wrap a Dataset inside prepare_dataloaders; the dataset holds the sequences (seq: (batch, lens)) plus the word2idx and idx2word mappings
    train_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn,
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)
    return train_loader, valid_loader
Example #14
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(  # As usual, define your own Dataset class and then load it with a DataLoader.
        TranslationDataset(  # Put the training data into the dataset.
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)
    # This dataset only has train and valid splits, no test split
    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn)
    return train_loader, valid_loader
Example #15
def prepare_dataloaders(data, opt):
    ''' Prepare Pytorch dataloaders '''
    train_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn,
                                               drop_last=False,
                                               shuffle=True)

    valid_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt']),
                                               num_workers=2,
                                               batch_size=opt.batch_size,
                                               collate_fn=paired_collate_fn,
                                               drop_last=False)
    return train_loader, valid_loader
Example #16
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data["dict"]["src"],
            tgt_word2idx=data["dict"]["tgt"],
            src_insts=data["train"]["src"],
            tgt_insts=data["train"]["tgt"]),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data["dict"]["src"],
            tgt_word2idx=data["dict"]["tgt"],
            src_insts=data["valid"]["src"],
            tgt_insts=data["valid"]["tgt"]),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)
    return train_loader, valid_loader
Example #17
def evaluateAndShowAttention(in_s, seq2seq, in_lang, out_lang, out_file):
    seq2seq.eval()
    src = TranslationDataset.to_ids(in_s, in_lang) + [EOS_idx]
    src_len = len(src)
    src = torch.LongTensor(src).view(1, -1).cuda()
    src_len = torch.tensor([src_len])
    dec_outs, attn_ws = seq2seq.generate(src, src_len)
    topi = dec_outs.topk(1)[1]  # [1, max_len, 1]
    out_words = idx2words(topi.squeeze(), out_lang)

    logger.info("input = {}".format(in_s))
    logger.info("output = {}".format(' '.join(out_words)))
    attn_ws = attn_ws.squeeze().detach().cpu()[:len(out_words)]
    image = showAttention(in_s, out_words, attn_ws, out_file)
    return attn_ws, image
Example #18
def main():
    """Main Function"""

    parser = argparse.ArgumentParser(description="translate.py")

    parser.add_argument("-model", required=True, help="Path to model .pt file")
    parser.add_argument(
        "-src",
        required=True,
        help="Source sequence to decode (one line per sequence)")
    parser.add_argument(
        "-vocab",
        required=True,
        help="Source sequence to decode (one line per sequence)")
    parser.add_argument("-output",
                        default="pred.txt",
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument("-beam_size", type=int, default=5, help="Beam size")
    parser.add_argument("-batch_size", type=int, default=30, help="Batch size")
    parser.add_argument("-n_best",
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument("-no_cuda", action="store_true")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data["settings"]
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data["dict"]["src"])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data["dict"]["src"],
        tgt_word2idx=preprocess_data["dict"]["tgt"],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(opt)

    with open(opt.output, "w") as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc="  - (Test)",
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = " ".join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq
                    ])
                    f.write(pred_line + "\n")
    print("[Info] Finished.")
Example #19
def main():
    ''' Main function '''
    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model',
                        required=True,
                        help='Path to model .chkpt file')
    parser.add_argument('-test_file',
                        required=True,
                        help='Test pickle file for validation')
    parser.add_argument(
        '-output',
        default='outputs.txt',
        help=
        'Path to output the predictions (each line will be the decoded sequence)'
    )
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=16, help='Batch size')
    parser.add_argument(
        '-n_best',
        type=int,
        default=1,
        help='If verbose is set, will output the n_best decoded sentences')
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    #- Prepare Translator
    translator = Translator(opt)
    print('[Info] Model opts: {}'.format(translator.model_opt))

    #- Prepare DataLoader
    test_data = torch.load(opt.test_file)

    test_src_insts = test_data['test']['src']
    test_tgt_insts = test_data['test']['tgt']

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=test_data['dict']['src'],
        tgt_word2idx=test_data['dict']['tgt'],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              drop_last=True,
                                              collate_fn=collate_fn)

    print('[Info] Evaluate on test set.')
    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Testing)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(
                *batch)  # structure: List[batch, seq, pos]
            for inst in all_hyp:
                f.write('[')
                for seq in inst:
                    seq = seq[0]
                    pred_seq = ' '.join([
                        test_loader.dataset.tgt_idx2word[word] for word in seq
                    ])
                    f.write('\t' + pred_seq + '\n')
                f.write(']\n')
    print('[Info] Finished.')
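One caveat in this example: drop_last=True on the test loader silently discards the final incomplete batch, so with batch_size=16 up to 15 test instances are never translated or written to the output file.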
Example #20
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(opt)

    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
Example #21
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-data_dir', required=True)
    parser.add_argument('-debug', action='store_true')
    parser.add_argument('-dir_out', default="/home/suster/Apps/out/")
    parser.add_argument(
        "--convert-consts",
        type=str,
        help="conv | our-map | no-our-map | no. \n/"
        "conv-> txt: -; stats: num_sym+ent_sym.\n/"
        "our-map-> txt: num_sym; stats: num_sym(from map)+ent_sym;\n/"
        "no-our-map-> txt: -; stats: num_sym(from map)+ent_sym;\n/"
        "no-> txt: -; stats: -, only ent_sym;\n/"
        "no-ent-> txt: -; stats: -, no ent_sym;\n/")
    parser.add_argument(
        "--label-type-dec",
        type=str,
        default="full-pl",
        help=
        "predicates | predicates-all | predicates-arguments-all | full-pl | full-pl-no-arg-id | full-pl-split | full-pl-split-plc | full-pl-split-stat-dyn. To use with EncDec."
    )
    parser.add_argument('-vocab', required=True)
    #parser.add_argument('-output', default='pred.txt',
    #                    help="""Path to output the predictions (each line will
    #                    be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    args = parser.parse_args()
    args.cuda = not args.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(args.vocab)
    preprocess_settings = preprocess_data['settings']

    if args.convert_consts in {"conv"}:
        assert "nums_mapped" not in args.data_dir
    elif args.convert_consts in {"our-map", "no-our-map", "no", "no-ent"}:
        assert "nums_mapped" in args.data_dir
    else:
        if args.convert_consts is not None:
            raise ValueError
    test_corp = Nlp4plpCorpus(args.data_dir + "test", args.convert_consts)

    if args.debug:
        test_corp.insts = test_corp.insts[:10]
    test_corp.get_labels(label_type=args.label_type_dec)
    test_corp.remove_none_labels()

    # Training set
    test_src_word_insts, test_src_id_insts = prepare_instances(test_corp.insts)
    test_tgt_word_insts, test_tgt_id_insts = prepare_instances(test_corp.insts,
                                                               label=True)
    assert test_src_id_insts == test_tgt_id_insts
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
                                              num_workers=0,
                                              batch_size=args.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(args)

    i = 0
    preds = []
    golds = []

    for batch in tqdm(test_loader,
                      mininterval=2,
                      desc='  - (Test)',
                      leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                pred = [
                    test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq
                    if test_loader.dataset.tgt_idx2word[idx] != "</s>"
                ]
                gold = [
                    w for w in test_tgt_word_insts[i]
                    if w not in {"<s>", "</s>"}
                ]
                if args.convert_consts == "no":
                    num2n = None
                else:
                    id = test_src_id_insts[i]
                    assert test_corp.insts[i].id == id
                    num2n = test_corp.insts[i].num2n_map
                pred = final_repl(pred, num2n)
                gold = final_repl(gold, num2n)
                preds.append(pred)
                golds.append(gold)
                i += 1
    acc = accuracy_score(golds, preds)
    print(f"Accuracy: {acc:.3f}")
    print("Saving predictions from the best model:")

    assert len(test_src_id_insts) == len(test_src_word_insts) == len(
        preds) == len(golds)
    f_model = f'{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}'
    dir_out = f"{args.dir_out}log_w{f_model}/"
    print(f"Save preds dir: {dir_out}")
    if not os.path.exists(dir_out):
        os.makedirs(dir_out)
    for (id, gold, pred) in zip(test_src_id_insts, golds, preds):
        f_name_t = os.path.basename(f"{id}.pl_t")
        f_name_p = os.path.basename(f"{id}.pl_p")
        with open(dir_out + f_name_t,
                  "w") as f_out_t, open(dir_out + f_name_p, "w") as f_out_p:
            f_out_t.write(gold)
            f_out_p.write(pred)

    #with open(args.output, 'w') as f:
    #   golds
    #    preds
    #    f.write("PRED: " + pred_line + '\n')
    #    f.write("GOLD: " + gold_line + '\n')

    print('[Info] Finished.')
Example #22
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='predict.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    # parser.add_argument('-src', required=True,
    #                     help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-data', required=True, help='preprocessed data file')
    parser.add_argument(
        '-original_data',
        default=config.FORMATTED_DATA,
        help='original data showing original text and equations')
    parser.add_argument(
        '-vocab',
        default=None,
        help=
        'data file for vocabulary. if not specified (default), use the one in -data'
    )
    parser.add_argument(
        '-split',
        type=float,
        default=0.8,
        help='Proportion of training data; the rest is test data.')
    parser.add_argument(
        '-offset',
        type=float,
        default=0,
        help="determin starting index of training set, for cross validation")
    parser.add_argument(
        '-output',
        default='pred.json',
        help=
        """Path to output the predictions (each line will be the decoded sequence"""
    )
    parser.add_argument('-beam_size', type=int, default=10, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-reset_num',
                        default=False,
                        action='store_true',
                        help='replace number symbols with real numbers')
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    print(opt)

    # Prepare DataLoader
    preprocess_data = torch.load(opt.data)
    if opt.original_data is not None:
        formmated_data = json.load(open(opt.original_data))
        formmated_map = {}
        for d in formmated_data:
            formmated_map[d['id']] = d

    N = preprocess_data['settings']['n_instances']
    train_len = int(N * opt.split)
    start_idx = int(opt.offset * N)  # start location of training data
    print("Data split: {}".format(opt.split))
    print("Training starts at: {} out of {} instances".format(start_idx, N))

    if start_idx + train_len < N:
        valid_src_insts = preprocess_data['src'][
            start_idx + train_len:] + preprocess_data['src'][:start_idx]
        valid_tgt_insts = preprocess_data['tgt'][
            start_idx + train_len:] + preprocess_data['tgt'][:start_idx]
    else:
        valid_len = N - train_len
        valid_start_idx = start_idx - valid_len

        valid_src_insts = preprocess_data['src'][valid_start_idx:start_idx]
        valid_tgt_insts = preprocess_data['tgt'][valid_start_idx:start_idx]

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=valid_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size)
    # collate_fn=collate_fn)
    test_loader.collate_fn = test_loader.dataset.collate_fn

    tgt_insts = valid_tgt_insts
    block_list = [preprocess_data['dict']['tgt'][UNK_WORD]]

    translator = Translator(opt)
    # translator = NTMTranslator(opt)
    translator.model.eval()

    output = []
    n = 0
    for batch in tqdm(test_loader,
                      mininterval=2,
                      desc='  - (Test)',
                      leave=False):
        with torch.no_grad():
            all_hyp_list, all_score_list = translator.translate_batch(
                *batch, block_list=block_list)
        for i, idx_seqs in enumerate(
                all_hyp_list[0]):  # loop over instances in batch
            scores = all_score_list[0][i]
            if translator.opt.bi:  # bidirectional
                idx_seqs_reverse = all_hyp_list[1][i]
                scores_reverse = all_score_list[1][i]

            for j, idx_seq in enumerate(idx_seqs):  # loop over n_best results
                d = {}
                question_id = preprocess_data['idx2id'][(n + train_len +
                                                         start_idx) % N]

                pred_line = ''.join(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                score = scores[j]
                if translator.opt.bi:
                    idx_seq_reverse = idx_seqs_reverse[j]
                    score_reverse = scores_reverse[j]
                    idx_seq_reverse.reverse()
                    pred_line_reverse = ''.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq_reverse
                    ])

                src_idx_seq = test_loader.dataset[n]  # truth
                src_text = ' '.join([
                    test_loader.dataset.src_idx2word[idx]
                    for idx in src_idx_seq
                ])
                tgt_text = ''.join([
                    test_loader.dataset.tgt_idx2word[idx]
                    for idx in tgt_insts[n]
                ])
                if opt.reset_num:
                    src_text = reset_numbers(
                        src_text,
                        preprocess_data['numbers'][(n + train_len + start_idx)
                                                   % N])
                    # tgt_text = reset_numbers(tgt_text, preprocess_data['numbers'][n + train_len])
                    tgt_text = ';'.join(
                        formmated_map[question_id]['equations'])

                    pred_line = reset_numbers(
                        pred_line,
                        preprocess_data['numbers'][(n + train_len + start_idx)
                                                   % N],
                        try_similar=True)
                    if translator.opt.bi:
                        pred_line_reverse = reset_numbers(
                            pred_line_reverse,
                            preprocess_data['numbers'][(n + train_len +
                                                        start_idx) % N],
                            try_similar=True)
                        # print(pred_line, tgt_text)
                        # print(pred_line_reverse, tgt_text, '\n')

                d['question'] = src_text
                d['ans'] = preprocess_data['ans'][(n + train_len + start_idx) %
                                                  N]
                d['id'] = question_id
                d['equation'] = tgt_text
                d['pred'] = (pred_line.replace('</s>',
                                               ''), round(score.item(), 3))
                if translator.opt.bi:
                    d['pred_2'] = (pred_line_reverse.replace('</s>', ''),
                                   round(score_reverse.item(), 3))

                output.append(d)
            n += 1

    with open(opt.output, 'w') as f:
        json.dump(output, f, indent=2)
    print('[Info] Finished.')
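The split arithmetic above places the training window at [start_idx, start_idx + train_len) and uses the remaining instances (wrapping around the end of the data) as the validation set. A worked example with made-up numbers: for N = 100, split = 0.8 and offset = 0.1, train_len = 80 and start_idx = 10; since 10 + 80 = 90 < 100, the validation set is src[90:] + src[:10], i.e. the 20 instances outside the training window. With offset = 0.5 the first branch no longer applies (50 + 80 > 100), so valid_len = 20, valid_start_idx = 30, and validation covers src[30:50], the 20 instances immediately before the training window.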
Example #23
def main():
    '''Main Function'''

    '''
    This model translates from English into German.
    '''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=False,
                        help='Path to model .pt file')
    parser.add_argument('-src', required=False,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=False,
                        help='Path to the preprocessed vocabulary file')
    parser.add_argument('-output', default='2',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30,
                        help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')



    # -vocab data/multi30k.atok.low.pt

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.cuda = False
    opt.model = 'trained.chkpt'
    opt.src = '1'
    opt.vocab = 'multi30k.atok.low.pt'
    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)

    tmp1=preprocess_data['dict']['src']
    tmp2=preprocess_data['dict']['tgt']
    with open('55', 'w') as f:
        f.write(str(tmp1))

    with open('66', 'w', encoding='utf-8') as f:
        f.write(str(tmp2))

    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    with open(opt.output, 'w') as f:
        for batch in test_loader:
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    print(idx_seq)
                    pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])  # convert the ids back to text
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
Example #24
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True,
                        help='Path to model .pt file')
    parser.add_argument('-vocab', required=True,
                        help='Path to vocabulary file')
    parser.add_argument('-output',
                        help="""Path to output the predictions""")
    parser.add_argument('-beam_size', type=int, default=5,
                        help='Beam size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    src_line = "Binary files a / build / linux / jre . tgz and b / build / linux / jre . tgz differ <nl>"

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances(
        src_line,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=1,
        collate_fn=collate_fn)

    translator = Translator(opt)


    for batch in tqdm(test_loader, mininterval=1, desc='  - (Test)', leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                pred_line = ' '.join([test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq[:-1]])
            print(pred_line)
    
    sent = src_line.split()
    tgt_sent = pred_line.split()
    
    for layer in range(0, 2):
        fig, axs = plt.subplots(1,4, figsize=(20, 10))
        print("Encoder Layer", layer+1)
        for h in range(4):
            print(translator.model.encoder.layer_stack[layer].slf_attn.attn.data.cpu().size())
            draw(translator.model.encoder.layer_stack[layer].slf_attn.attn[h, :, :].data.cpu(), 
                sent, sent if h ==0 else [], ax=axs[h])
        plt.savefig(opt.output+"Encoder Layer %d.png" % layer)
        
    for layer in range(0, 2):
        fig, axs = plt.subplots(1,4, figsize=(20, 10))
        print("Decoder Self Layer", layer+1)
        for h in range(4):
            print(translator.model.decoder.layer_stack[layer].slf_attn.attn.data.cpu().size())
            draw(translator.model.decoder.layer_stack[layer].slf_attn.attn[:,:, h].data[:len(tgt_sent), :len(tgt_sent)].cpu(), 
                tgt_sent, tgt_sent if h ==0 else [], ax=axs[h])
        plt.savefig(opt.output+"Decoder Self Layer %d.png" % layer)

        print("Decoder Src Layer", layer+1)
        fig, axs = plt.subplots(1,4, figsize=(20, 10))
        for h in range(4):
            draw(translator.model.decoder.layer_stack[layer].slf_attn.attn[:,:, h].data[:len(sent), :len(tgt_sent)].cpu(), 
                tgt_sent, sent if h ==0 else [], ax=axs[h])
        plt.savefig(opt.output+"Decoder Src Layer %d.png" % layer)
                    
    print('[Info] Finished.')
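One detail worth double-checking in the plotting loops above: the "Decoder Src Layer" panels index translator.model.decoder.layer_stack[layer].slf_attn again, i.e. the decoder self-attention. If the decoder layers expose a separate encoder-decoder attention module (commonly named enc_attn in this style of Transformer implementation), that attribute is presumably what the source-attention plot was meant to visualize.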
Example #25
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-target',
        required=True,
        help='Target sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Path to the preprocessed vocabulary file')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    parser.add_argument('-prune', action='store_true')
    parser.add_argument('-prune_alpha', type=float, default=0.1)
    parser.add_argument('-load_mask', type=str, default=None)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']

    refs = read_instances_from_file(opt.target,
                                    preprocess_settings.max_word_seq_len,
                                    preprocess_settings.keep_case)

    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts,
    ),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(opt)

    preds = []
    preds_text = []

    for batch in tqdm(test_loader,
                      mininterval=2,
                      desc='  - (Test)',
                      leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                sent = ' '.join(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                sent = sent.split("</s>")[0].strip()
                sent = sent.replace("▁", " ")
                preds_text.append(sent.strip())
                preds.append(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
    with open(opt.output, 'w') as f:
        f.write('\n'.join(preds_text))

    from evaluator import BLEUEvaluator
    scorer = BLEUEvaluator()
    length = min(len(preds), len(refs))
    score = scorer.evaluate(refs[:length], preds[:length])
    print(score)
Example #26
    in_lang_path = f"cache/in-fra-{max_len}-{min_freq}.pkl"
    out_lang_path = f"cache/out-eng-{max_len}-{min_freq}.pkl"
    pair_path = f"cache/fra2eng-{max_len}.pkl"
    exist_all = all(
        os.path.exists(path)
        for path in [in_lang_path, out_lang_path, pair_path])
    if not exist_all:
        data_prepare.prepare(max_len, min_freq)

    input_lang = Lang.load_from_file("fra", in_lang_path)
    output_lang = Lang.load_from_file("eng", out_lang_path)
    pairs = pickle.load(open(pair_path, "rb"))
    logger.info("\tinput_lang.n_words = {}".format(input_lang.n_words))
    logger.info("\toutput_lang.n_words = {}".format(output_lang.n_words))
    logger.info("\t# of pairs = {}".format(len(pairs)))
    dset = TranslationDataset(input_lang, output_lang, pairs, max_len)
    logger.info(random.choice(pairs))

    # split dset by valid indices
    N_pairs = len(pairs)
    val_indices_path = f"cache/valid_indices-{N_pairs}.npy"
    if not os.path.exists(val_indices_path):
        data_prepare.gen_valid_indices(N_pairs, 0.1, val_indices_path)
    valid_indices = np.load(val_indices_path)
    train_indices = list(set(range(len(dset))) - set(valid_indices))
    train_dset = Subset(dset, train_indices)
    valid_dset = Subset(dset, valid_indices)

    # loader
    collate_fn = src_sort if model_type == 'rnn' else torch.utils.data.dataloader.default_collate
    logger.info("Load loader")
Example #27
def main():
    """Main Function"""

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-src',
                        required=True,
                        help='Source sequence to decode '
                        '(one line per sequence)')
    parser.add_argument('-tgt',
                        required=True,
                        help='Target sequence to decode '
                        '(one line per sequence)')
    parser.add_argument('-vocab',
                        required=True,
                        help='Path to the preprocessed vocabulary file')
    parser.add_argument('-log',
                        default='translate_log.txt',
                        help="""Path to log the translation(test_inference) 
                        loss""")
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_tgt_word_insts = read_instances_from_file(
        opt.tgt, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_tgt_insts = convert_instance_to_idx_seq(
        test_tgt_word_insts, preprocess_data['dict']['tgt'])

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts,
        tgt_insts=test_tgt_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=paired_collate_fn)

    translator = Translator(opt)

    n_word_total = 0
    n_word_correct = 0

    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            # all_hyp, all_scores = translator.translate_batch(*batch)
            all_hyp, all_scores = translator.translate_batch(
                batch[0], batch[1])

            # print(all_hyp)
            # print(all_hyp[0])
            # print(len(all_hyp[0]))

            # Pad each hypothesis with 0's (PAD) so it fits the gold max_len of this batch.
            # The batch presumably unpacks as (src_seq, src_pos, tgt_seq, tgt_pos) from paired_collate_fn.
            src_seqs = batch[0]
            tgt_seqs = batch[2]
            gold = tgt_seqs[:, 1:]  # drop the first (BOS) token so gold aligns with the decoder output
            max_len = gold.shape[1]

            pred_seq = []
            for item in all_hyp:
                curr_item = item[0]
                curr_len = len(curr_item)
                # print(curr_len, max_len)
                # print(curr_len)
                if curr_len < max_len:
                    diff = max_len - curr_len
                    curr_item.extend([0] * diff)
                else:
                    # Beam search decodes up to its own length limit, independent of the
                    # gold length, so longer hypotheses are truncated for the comparison.
                    curr_item = curr_item[:max_len]
                pred_seq.append(curr_item)
            pred_seq = torch.LongTensor(np.array(pred_seq))
            # Flatten to (batch * max_len,); the last batch may be smaller than
            # opt.batch_size, so avoid hard-coding the batch size in the reshape.
            pred_seq = pred_seq.view(-1)

            n_correct = cal_performance(pred_seq, gold)

            non_pad_mask = gold.ne(Constants.PAD)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct

            # trs_log = "transformer_loss: {} |".format(trs_loss)
            #
            # with open(opt.log, 'a') as log_tf:
            #     log_tf.write(trs_log + '\n')

            count = 0
            for hyp_seqs in all_hyp:
                src_seq = src_seqs[count]
                tgt_seq = tgt_seqs[count]
                for hyp_seq in hyp_seqs:
                    src_line = ' '.join([
                        test_loader.dataset.src_idx2word[idx]
                        for idx in src_seq.data.cpu().numpy()
                    ])
                    tgt_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in tgt_seq.data.cpu().numpy()
                    ])
                    pred_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in hyp_seq
                    ])
                    f.write(
                        "\n ----------------------------------------------------------------------------------------------------------------------------------------------  \n"
                    )
                    f.write("\n [src]  " + src_line + '\n')
                    f.write("\n [tgt]  " + tgt_line + '\n')
                    f.write("\n [pred] " + pred_line + '\n')

                # Advance once per source sentence (not per hypothesis) so that
                # src_seqs/tgt_seqs stay aligned with all_hyp when n_best > 1.
                count += 1

        accuracy = n_word_correct / n_word_total
        accr_log = "accuracy: {} |".format(accuracy)
        # print(accr_log)

        with open(opt.log, 'a') as log_tf:
            log_tf.write(accr_log + '\n')

    print('[Info] Finished.')
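The example above relies on `cal_performance` from the project, which is not shown here. A minimal sketch, assuming it only counts token-level matches while ignoring padded gold positions (`pad_idx` stands in for `Constants.PAD`), might look like:

def cal_performance(pred_seq, gold, pad_idx=0):
    # pred_seq: flat (batch * max_len,) tensor of predicted token indices
    # gold:     (batch, max_len) tensor of gold token indices, PAD-padded
    gold = gold.contiguous().view(-1)
    non_pad_mask = gold.ne(pad_idx)
    n_correct = pred_seq.eq(gold).masked_select(non_pad_mask).sum().item()
    return n_correct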
Beispiel #28
0
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='translate.py')

    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Preprocessed data (.pt) file that provides the vocabulary')
    parser.add_argument('-output',
                        default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence)""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best',
                        type=int,
                        default=1,
                        help="""Number of n-best decoded sentences to output
                        per source sentence""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src, preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    # pdb.set_trace()
    # (Pdb) print(opt)
    # Namespace(batch_size=30, beam_size=5, cuda=True, model='trained.chkpt',
    #     n_best=1, no_cuda=False, output='pred.txt', src='data/multi30k/test.en.atok',
    #     vocab='data/multi30k.atok.low.pt')

    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
                                              num_workers=2,
                                              batch_size=opt.batch_size,
                                              collate_fn=collate_fn)

    translator = Translator(opt)

    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader,
                          mininterval=2,
                          desc='  - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    pred_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq
                    ])
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
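This example's `collate_fn` (source side only) comes from the project's dataset module and is not reproduced here. A minimal sketch of what such a function typically does -- padding with index 0 and building 1-based position indices, both assumptions -- is:

import numpy as np
import torch

def collate_fn(insts, pad_idx=0):
    # Pad each source index sequence to the batch maximum and build the
    # matching position tensor, with 0 marking padded slots.
    max_len = max(len(inst) for inst in insts)
    seq = np.array([inst + [pad_idx] * (max_len - len(inst)) for inst in insts])
    pos = np.array([[i + 1 if tok != pad_idx else 0 for i, tok in enumerate(inst)]
                    for inst in seq])
    return torch.LongTensor(seq), torch.LongTensor(pos)

Under that assumption, each batch yielded by the loader is a (src_seq, src_pos) pair, which is consistent with the `translator.translate_batch(*batch)` call above.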
Beispiel #29
0
def prepare_dataloaders(data, opt):
    # ========= Preparing DataLoader =========#

    N = data['settings']['n_instances']
    train_len = int(N * opt.split)
    start_idx = int(opt.offset * N)
    print("Data split: {}".format(opt.split))
    print("Training starts at: {} out of {} instances".format(start_idx, N))

    if start_idx + train_len < N:
        train_src_insts = data['src'][start_idx: start_idx + train_len]
        train_tgt_insts = data['tgt'][start_idx: start_idx + train_len]
        train_tgt_nums = data['tgt_nums'][start_idx: start_idx + train_len]

        valid_src_insts = data['src'][start_idx + train_len:] + data['src'][:start_idx]
        valid_tgt_insts = data['tgt'][start_idx + train_len:] + data['tgt'][:start_idx]
        valid_tgt_nums = data['tgt_nums'][start_idx + train_len:] + data['tgt_nums'][:start_idx]
    else:
        valid_len = N - train_len
        valid_start_idx = start_idx - valid_len

        train_src_insts = data['src'][start_idx:] + data['src'][:valid_start_idx]
        train_tgt_insts = data['tgt'][start_idx:] + data['tgt'][:valid_start_idx]
        train_tgt_nums = data['tgt_nums'][start_idx:] + data['tgt_nums'][:valid_start_idx]

        valid_src_insts = data['src'][valid_start_idx: start_idx]
        valid_tgt_insts = data['tgt'][valid_start_idx: start_idx]
        valid_tgt_nums = data['tgt_nums'][valid_start_idx: start_idx]

    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=train_src_insts,
            tgt_insts=train_tgt_insts,
            tgt_nums=train_tgt_nums,
            permute_tgt=False),
        num_workers=2,
        batch_size=opt.batch_size,
        # collate_fn=collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=valid_src_insts,
            tgt_insts=valid_tgt_insts,
            tgt_nums=valid_tgt_nums,
            permute_tgt=False),
        num_workers=2,
        batch_size=opt.batch_size)
        # collate_fn=collate_fn)

    if opt.bi:
        train_loader.collate_fn = train_loader.dataset.bidirectional_collate_fn
        valid_loader.collate_fn = valid_loader.dataset.bidirectional_collate_fn
    else:
        train_loader.collate_fn = train_loader.dataset.paired_collate_fn
        valid_loader.collate_fn = valid_loader.dataset.paired_collate_fn

    return train_loader, valid_loader
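To make the wrap-around split above concrete, here is a small worked example with illustrative numbers (not taken from the source data):

# Illustrative values only
N, split, offset = 1000, 0.8, 0.7
train_len = int(N * split)               # 800
start_idx = int(offset * N)              # 700 -> start_idx + train_len >= N, so the else branch runs
valid_len = N - train_len                # 200
valid_start_idx = start_idx - valid_len  # 500

train_idx = list(range(start_idx, N)) + list(range(0, valid_start_idx))  # 700..999 then 0..499
valid_idx = list(range(valid_start_idx, start_idx))                      # 500..699
assert len(train_idx) == train_len and len(valid_idx) == valid_len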
Beispiel #30
0
def main():
    '''Main Function'''

    parser = argparse.ArgumentParser(description='reinforcement training')

    parser.add_argument('-model',
                        required=True,
                        help='Path to pretrained model .pt file')
    # parser.add_argument('-src', required=True,
    #                     help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-data', required=True, help='preprocessed data file')
    parser.add_argument(
        '-original_data',
        default=config.FORMATTED_DATA,
        help='original data showing original text and equations')
    parser.add_argument(
        '-vocab',
        default=None,
        help='Data file for the vocabulary; if not specified (default), '
        'the one in -data is used')
    parser.add_argument(
        '-split',
        type=float,
        default=0.8,
        help='Proportion of training data; the rest is test data.')
    parser.add_argument(
        '-offset',
        type=float,
        default=0,
        help='Determines the starting index of the training set, '
        'for cross-validation')
    parser.add_argument('-save_model',
                        default=None,
                        help="model destination path")
    parser.add_argument('-beam_size', type=int, default=8, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=4, help='Batch size')
    parser.add_argument(
        '-n_best',
        type=int,
        default=8,
        help="Number of n-best decoded sentences to output per source sentence")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-epochs', type=int, default=100)
    parser.add_argument('-teacher_ratio',
                        type=float,
                        default=0.,
                        help="probability to allow teacher forcing")
    parser.add_argument('-permute',
                        action='store_true',
                        help="permute equations for training")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.reset_num = True  # use numbers (not symbols) in output
    print(opt)

    # Prepare DataLoader
    preprocess_data = torch.load(opt.data)
    if opt.original_data is not None:
        formatted_data = json.load(open(opt.original_data))
        formatted_map = {}
        for d in formatted_data:
            formatted_map[d['id']] = d

    N = preprocess_data['settings']['n_instances']
    train_len = int(N * opt.split)
    start_idx = int(opt.offset * N)
    print("Data split: {}".format(opt.split))
    print("Training starts at: {} out of {} instances".format(start_idx, N))

    if start_idx + train_len < N:
        train_src_insts = preprocess_data['src'][start_idx:start_idx +
                                                 train_len]
        train_tgt_insts = preprocess_data['tgt'][start_idx:start_idx +
                                                 train_len]
        train_tgt_nums = preprocess_data['tgt_nums'][start_idx:start_idx +
                                                     train_len]
    else:
        valid_len = N - train_len
        valid_start_idx = start_idx - valid_len

        train_src_insts = preprocess_data['src'][start_idx:] + preprocess_data[
            'src'][:valid_start_idx]
        train_tgt_insts = preprocess_data['tgt'][start_idx:] + preprocess_data[
            'tgt'][:valid_start_idx]
        train_tgt_nums = preprocess_data['tgt_nums'][
            start_idx:] + preprocess_data['tgt_nums'][:valid_start_idx]

    data_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=train_src_insts,
        tgt_insts=train_tgt_insts,
        tgt_nums=train_tgt_nums,
        permute_tgt=False),
                                              num_workers=1,
                                              batch_size=opt.batch_size)
    # collate_fn=collate_fn)
    # data_loader.collate_fn = data_loader.dataset.collate_fn
    data_loader.collate_fn = data_loader.dataset.bidirectional_collate_fn

    # tgt_insts = preprocess_data['tgt'][:train_len]
    # block_list = [preprocess_data['dict']['tgt'][UNK_WORD]]

    translator = Translator(opt)
    original_max_token_seq_len = translator.model_opt.max_token_seq_len
    translator.model.train()

    # set teacher forcing training optimizer
    optimizer_teacher = Scheduler(optim.Adam(filter(
        lambda x: x.requires_grad, translator.model.parameters()),
                                             betas=(0.9, 0.98),
                                             eps=1e-09),
                                  alpha=1e-6)
    # set reinforcement training optimizer
    optimizer_reinforce = Scheduler(optim.Adam(filter(
        lambda x: x.requires_grad, translator.model.parameters()),
                                               betas=(0.9, 0.98),
                                               eps=1e-09),
                                    alpha=5e-7)  # 1e-8

    for epoch in range(opt.epochs):
        start = time.time()
        instance_idx = start_idx
        n_correct = 0
        total_loss = 0
        optimizer_reinforce.n_current_steps += 1

        # for gcl
        translator.model.encoder.gcl.init_sequence(1)
        translator.model.encoder.memory_ready = False

        for batch in tqdm(data_loader,
                          mininterval=2,
                          desc='  - (Train)',
                          leave=True):
            # batch: (*src_insts, *tgt_insts, *tgt_nums_insts)
            # print(batch[0]);sys.exit(1)
            translator.model_opt.max_token_seq_len = 32  # keep decoding length short so training stays manageable
            all_hyp_list, all_score_list = translator.translate_batch(
                batch[0], batch[1], block_list=[])

            # reinforcement training
            batch_loss, batch_n_correct = train_batch(
                all_hyp_list, all_score_list, translator, data_loader,
                preprocess_data, formatted_map, instance_idx, opt)
            optimizer_reinforce.zero_grad()

            # # for gcl
            # memory = translator.model.encoder.gcl.memory
            # print(memory[-1])
            # translator.model.encoder.gcl.init_sequence(1)
            # translator.model.encoder.gcl.memory = memory
            # translator.model.encoder.gcl.gcl.memory = memory
            #for head in translator.model.encoder.gcl.gcl.heads:
            #    head.memory = memory

            batch_loss.backward()
            optimizer_reinforce.step_and_update_lr()

            total_loss += batch_loss.item()
            n_correct += batch_n_correct
            instance_idx += opt.batch_size
            instance_idx = instance_idx % N

            # if batch_n_correct / opt.batch_size < 0.3:
            #     # teacher forcing training
            #     teacher_train_batch(translator.model, batch, optimizer_teacher, translator.device,
            #                         bidirectional=translator.opt.bi)

        # end of epoch
        train_acc = n_correct / train_len
        total_loss = total_loss * opt.batch_size / train_len
        sys.stdout.write('\n  - (Training)   ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, ' \
              'elapse: {elapse:3.3f} min\n'.format(
            ppl=math.exp(min(total_loss, 100)), accu=100 * train_acc, elapse=(time.time() - start) / 60))
        sys.stdout.flush()

        model_state_dict = translator.model.state_dict()
        translator.model_opt.max_token_seq_len = original_max_token_seq_len
        checkpoint = {
            'model': model_state_dict,
            'memory': translator.model.encoder.gcl.memory,
            'settings': translator.model_opt,
            'epoch': epoch
        }
        model_name = opt.save_model + '.chkpt'
        torch.save(checkpoint, model_name)
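`Scheduler` is a project-specific optimizer wrapper that these examples never define; judging only from the calls made above (`zero_grad`, `step_and_update_lr`, and the externally bumped `n_current_steps`), a minimal sketch with an assumed inverse-square-root decay scaled by `alpha` could be:

class Scheduler:
    """Thin optimizer wrapper; the decay rule here is an assumption, not the project's."""

    def __init__(self, optimizer, alpha):
        self._optimizer = optimizer
        self.alpha = alpha
        self.n_current_steps = 0

    def zero_grad(self):
        self._optimizer.zero_grad()

    def step_and_update_lr(self):
        self.n_current_steps += 1
        lr = self.alpha / (self.n_current_steps ** 0.5)  # assumed schedule
        for group in self._optimizer.param_groups:
            group['lr'] = lr
        self._optimizer.step()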