def prepare_dataloaders(data, opt):
    """Build the training and validation DataLoaders.

    `data` carries the shared vocabularies under 'dict' and the
    index-encoded sentence pairs under 'train'/'valid'; batch size comes
    from `opt.batch_size`.
    """
    def _build(split, shuffle):
        # TranslationDataset implements __len__/__getitem__, where each item
        # is a (src, tgt) pair — hence the paired collate function.
        dataset = TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data[split]['src'],
            tgt_insts=data[split]['tgt'])
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn,
            shuffle=shuffle)

    # Only the training split is reshuffled each epoch.
    train_loader = _build('train', True)
    valid_loader = _build('valid', False)
    return train_loader, valid_loader
def prepare_dataloaders(data, opt, distributed):
    """Build train/valid DataLoaders, sharding the training set across
    processes with a DistributedSampler when `distributed` is true.
    """
    training_set = TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt'])

    # DataLoader forbids combining a sampler with shuffle=True, so shuffling
    # is enabled only when no distributed sampler is in play.
    sampler = (torch.utils.data.distributed.DistributedSampler(training_set)
               if distributed else None)

    train_loader = torch.utils.data.DataLoader(
        training_set,
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=sampler is None,
        sampler=sampler)

    validation_set = TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['valid']['src'],
        tgt_insts=data['valid']['tgt'])
    valid_loader = torch.utils.data.DataLoader(
        validation_set,
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader
def prepare_dataloaders(data, opt):
    """Create train/valid DataLoaders with pinned memory and no shuffling.

    Worker count is taken from `opt.num_workers` (coerced to int).
    """
    workers = int(opt.num_workers)

    def _loader(split):
        return torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=data['dict']['src'],
                tgt_word2idx=data['dict']['tgt'],
                src_insts=data[split]['src'],
                tgt_insts=data[split]['tgt']),
            num_workers=workers,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn,
            pin_memory=True,   # speeds host->GPU copies
            shuffle=False)     # keep deterministic order for both splits

    return _loader('train'), _loader('valid')
def prepare_dataloaders(data, opt):
    """Build loaders plus vocab sizes and idx->word maps for decoding.

    Returns (train_loader, valid_loader, src_vocab_size, trg_vocab_size,
    src_idx2word, trg_idx2word).
    """
    print(data["settings"])

    def _loader(split, shuffle=False):
        return torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=data["dict"]["src"],
                tgt_word2idx=data["dict"]["tgt"],
                src_insts=data[split]["src"],
                tgt_insts=data[split]["tgt"],
            ),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn,
            shuffle=shuffle,
        )

    train_loader = _loader("train", shuffle=True)
    valid_loader = _loader("valid")

    src_vocab_size = train_loader.dataset.src_vocab_size
    trg_vocab_size = train_loader.dataset.tgt_vocab_size
    # Invert the word->index vocabularies so predictions can be mapped
    # back to text.
    src_idx2word = {i: w for w, i in data['dict']['src'].items()}
    trg_idx2word = {i: w for w, i in data['dict']['tgt'].items()}

    return (train_loader, valid_loader, src_vocab_size, trg_vocab_size,
            src_idx2word, trg_idx2word)
def prepare_dataloaders(data, opt):
    """Split the 'train' data into train/valid subsets (90/10, seeded
    shuffle) and expose the original 'valid' split as the test loader.
    """
    validation_split = 0.1
    shuffle_dataset = True
    random_seed = 42

    full_set = TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=data['train']['src'],
        tgt_insts=data['train']['tgt'],
        sp_insts=data['train']['sp'])

    # Deterministically shuffle indices, then carve the first 10% off as
    # the validation subset.
    total = len(full_set)
    order = list(range(total))
    cut = int(np.floor(validation_split * total))
    if shuffle_dataset:
        np.random.seed(random_seed)
        np.random.shuffle(order)
    train_idx, valid_idx = order[cut:], order[:cut]

    train_loader = torch.utils.data.DataLoader(
        full_set,
        num_workers=4,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        sampler=SubsetRandomSampler(train_idx))
    valid_loader = torch.utils.data.DataLoader(
        full_set,
        num_workers=4,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        sampler=SubsetRandomSampler(valid_idx))

    # The preprocessed 'valid' split serves as the held-out test set.
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt'],
            sp_insts=data['valid']['sp']),
        num_workers=4,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader, test_loader
def prepare_dataloaders(data, opt):
    """Build the training and validation DataLoaders.

    `data` holds the vocabularies under 'dict' and the index-encoded
    instances under 'train'/'valid'. Only the training loader shuffles.
    (A large block of commented-out debug output — a printed sample batch —
    was removed; it added no information and obscured the function body.)
    """
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader
def test(opt): """ Functions to test the model and implement machine translation""" # Prepare DataLoader preprocess_data = t.load(opt.vocab) preprocess_settings = preprocess_data['settings'] test_src_word_insts = read_instances_from_file( opt.src, preprocess_settings.max_word_seq_len, preprocess_settings.keep_case) test_src_insts = convert_instance_to_idx_seq( test_src_word_insts, preprocess_data['dict']['src']) test_loader = t.utils.data.DataLoader(TranslationDataset( src_word2idx=preprocess_data['dict']['src'], tgt_word2idx=preprocess_data['dict']['tgt'], src_insts=test_src_insts), num_workers=2, batch_size=opt.batch_size, collate_fn=collate_fn) translator = TransformerTranslator(opt) with open(opt.output, 'w', encoding='utf-8') as f: for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)', leave=False): all_hyp, all_scores = translator.translate_batch(*batch) for idx_seqs in all_hyp: for idx_seq in idx_seqs: pred_line = ' '.join([ test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq ]) f.write(pred_line + '\n') print('[Info] Finished.')
def main():
    # Command-line entry point: translate sentences from -src with a
    # trained model and write predictions to -output.
    parser = argparse.ArgumentParser()
    parser.add_argument('-model', required=True,
                        help='Path to model .pt file')
    parser.add_argument(
        '-src', required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab', required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')  # flag: True when passed
    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader: load preprocessed vocab/settings, encode the raw
    # source file, and build a source-only (inference) dataset.
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    # Decode each batch and write one predicted sentence per line.
    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader, mininterval=2,
                          desc=' - (Test)', leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for hyp_stream in all_hyp:
                for hyp in hyp_stream:
                    # Map target indices back to words.
                    pred_sent = ' '.join(
                        [test_loader.dataset.tgt_idx2word[idx] for idx in hyp])
                    f.write(pred_sent + '\n')
    print('[Info] Finished')
def prepare_dataloaders(data, mined_data, opt):
    """Build train/valid/test loaders from `data`, plus one loader over the
    mined parallel corpus in `mined_data`. Only training shuffles.
    """
    def _loader(source, split, shuffle=False):
        return torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=source['dict']['src'],
                tgt_word2idx=source['dict']['tgt'],
                src_insts=source[split]['src'],
                tgt_insts=source[split]['tgt']),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn,
            shuffle=shuffle)

    train_loader = _loader(data, 'train', shuffle=True)
    valid_loader = _loader(data, 'valid')
    test_loader = _loader(data, 'test')
    # The mined corpus uses its own vocabularies from mined_data['dict'].
    mined_loader = _loader(mined_data, 'train')

    return train_loader, valid_loader, test_loader, mined_loader
def prepare_dataloaders(data, opt):
    """Language-modelling loaders: a single shared vocabulary (data['dict'])
    serves as both source and target word index.

    NOTE(review): both loaders read data['train'] — the second is named
    "valid" yet iterates the training instances; confirm this is intentional
    rather than a copy-paste slip before changing it.
    """
    def _loader(shuffle):
        return torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=data['dict'],
                tgt_word2idx=data['dict'],  # same word2idx for language modelling
                src_insts=data['train']),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=collate_fn,
            shuffle=shuffle)

    train_loader = _loader(True)
    valid_loader = _loader(False)
    return train_loader, valid_loader
def prepare_dataloaders(data, opt):
    """Pre-shuffle the training data at *batch* granularity, then load it
    sequentially.

    Contiguous batches of size opt.batch_size are formed over the original
    order, the batch order is shuffled, and the flattened index list is
    used to rewrite data['train'] in place; the DataLoader then runs with
    shuffle=False so each batch keeps its contiguous contents.
    """
    print("整合数据...")
    src_pre = data['train']['src']
    tgt_pre = data['train']['tgt']
    all_data = list(zip(src_pre, tgt_pre))

    print("sample 数据...")
    sampler = BatchSampler(SequentialSampler(all_data),
                           batch_size=opt.batch_size, drop_last=False)
    batches = list(sampler)          # one index-list per batch
    random.shuffle(batches)          # shuffle batch order only
    flat_index = list(itertools.chain.from_iterable(batches))

    print("重新赋值...")
    data['train']['src'] = [src_pre[i] for i in flat_index]
    data['train']['tgt'] = [tgt_pre[i] for i in flat_index]

    # num_workers=0: load in the main process.
    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        num_workers=0,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=False)
    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        num_workers=0,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader
def prepare_dataloaders(data, opt):
    """Build train and valid DataLoaders (neither loader shuffles)."""
    loaders = []
    for split in ('train', 'valid'):
        dataset = TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data[split]['src'],
            tgt_insts=data[split]['tgt'])
        loaders.append(torch.utils.data.DataLoader(
            dataset,
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn))
    train_loader, valid_loader = loaders
    return train_loader, valid_loader
def prepare_dataloaders(data, opt):
    """Wrap the encoded train/valid instances in DataLoaders.

    The underlying TranslationDataset carries the sequences together with
    the word2idx/idx2word vocabularies used later for decoding.
    """
    common = dict(num_workers=2,
                  batch_size=opt.batch_size,
                  collate_fn=paired_collate_fn)

    train_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['train']['src'],
            tgt_insts=data['train']['tgt']),
        shuffle=True,   # new batch order every epoch
        **common)

    valid_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data['valid']['src'],
            tgt_insts=data['valid']['tgt']),
        **common)

    return train_loader, valid_loader
def prepare_dataloaders(data, opt):
    """Build DataLoaders for the two available splits (train and valid;
    this dataset has no test split)."""
    def _dataset(split):
        return TranslationDataset(
            src_word2idx=data['dict']['src'],
            tgt_word2idx=data['dict']['tgt'],
            src_insts=data[split]['src'],
            tgt_insts=data[split]['tgt'])

    train_loader = torch.utils.data.DataLoader(
        _dataset('train'),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn,
        shuffle=True)

    valid_loader = torch.utils.data.DataLoader(
        _dataset('valid'),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader
def prepare_dataloaders(data, opt):
    ''' Prepare Pytorch dataloaders '''
    # drop_last=False on both loaders: the final partial batch is kept.
    def _build(split, shuffle=False):
        return torch.utils.data.DataLoader(
            TranslationDataset(
                src_word2idx=data['dict']['src'],
                tgt_word2idx=data['dict']['tgt'],
                src_insts=data[split]['src'],
                tgt_insts=data[split]['tgt']),
            num_workers=2,
            batch_size=opt.batch_size,
            collate_fn=paired_collate_fn,
            drop_last=False,
            shuffle=shuffle)

    return _build('train', shuffle=True), _build('valid')
def prepare_dataloaders(data, opt):
    """Build the train and validation DataLoaders from the preprocessed
    data dict (vocabularies under "dict", instances under each split)."""
    vocab_src = data["dict"]["src"]
    vocab_tgt = data["dict"]["tgt"]

    train_set = TranslationDataset(
        src_word2idx=vocab_src,
        tgt_word2idx=vocab_tgt,
        src_insts=data["train"]["src"],
        tgt_insts=data["train"]["tgt"])
    valid_set = TranslationDataset(
        src_word2idx=vocab_src,
        tgt_word2idx=vocab_tgt,
        src_insts=data["valid"]["src"],
        tgt_insts=data["valid"]["tgt"])

    # Training reshuffles every epoch; validation keeps a fixed order.
    train_loader = torch.utils.data.DataLoader(
        train_set, num_workers=2, batch_size=opt.batch_size,
        collate_fn=paired_collate_fn, shuffle=True)
    valid_loader = torch.utils.data.DataLoader(
        valid_set, num_workers=2, batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    return train_loader, valid_loader
def evaluateAndShowAttention(in_s, seq2seq, in_lang, out_lang, out_file):
    # Translate one input sentence with `seq2seq` and render its attention
    # map to `out_file`; returns (attention weights, rendered image).
    seq2seq.eval()
    # Encode the input string to token ids and append end-of-sequence.
    src = TranslationDataset.to_ids(in_s, in_lang) + [EOS_idx]
    src_len = len(src)
    src = torch.LongTensor(src).view(1, -1).cuda()  # batch of one; requires CUDA
    src_len = torch.tensor([src_len])
    dec_outs, attn_ws = seq2seq.generate(src, src_len)
    # Greedy pick: highest-scoring token index per decoder position.
    topi = dec_outs.topk(1)[1]  # [1, max_len, 1]
    out_words = idx2words(topi.squeeze(), out_lang)
    logger.info("input = {}".format(in_s))
    logger.info("output = {}".format(' '.join(out_words)))
    # Keep only the attention rows for the words actually emitted, and move
    # the weights off the GPU before plotting.
    attn_ws = attn_ws.squeeze().detach().cpu()[:len(out_words)]
    image = showAttention(in_s, out_words, attn_ws, out_file)
    return attn_ws, image
def main():
    """Main Function"""
    # Translate sentences from -src with a trained model; one decoded
    # sentence per output line.
    parser = argparse.ArgumentParser(description="translate.py")
    parser.add_argument("-model", required=True, help="Path to model .pt file")
    parser.add_argument(
        "-src", required=True,
        help="Source sequence to decode (one line per sequence)")
    parser.add_argument(
        "-vocab", required=True,
        help="Source sequence to decode (one line per sequence)")
    parser.add_argument("-output", default="pred.txt",
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument("-beam_size", type=int, default=5, help="Beam size")
    parser.add_argument("-batch_size", type=int, default=30, help="Batch size")
    parser.add_argument("-n_best", type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument("-no_cuda", action="store_true")
    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader: load preprocessed vocab/settings, index-encode the
    # raw source sentences, and build a source-only (inference) dataset.
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data["settings"]
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data["dict"]["src"])
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data["dict"]["src"],
            tgt_word2idx=preprocess_data["dict"]["tgt"],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    # Decode each batch and write one predicted sentence per line.
    with open(opt.output, "w") as f:
        for batch in tqdm(test_loader, mininterval=2,
                          desc=" - (Test)", leave=False):
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    # Map target indices back to words.
                    pred_line = " ".join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq
                    ])
                    f.write(pred_line + "\n")
    print("[Info] Finished.")
def main():
    ''' Main function '''
    # Evaluate a trained model on the 'test' split of a preprocessed pickle
    # file; predictions are written in a bracketed, tab-indented format.
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True,
                        help='Path to model .chkpt file')
    parser.add_argument('-test_file', required=True,
                        help='Test pickle file for validation')
    parser.add_argument(
        '-output', default='outputs.txt',
        help=
        'Path to output the predictions (each line will be the decoded sequence'
    )
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=16, help='Batch size')
    parser.add_argument(
        '-n_best', type=int, default=1,
        help='If verbose is set, will output the n_best decoded sentences')
    parser.add_argument('-no_cuda', action='store_true')
    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    #- Prepare Translator
    translator = Translator(opt)
    print('[Info] Model opts: {}'.format(translator.model_opt))

    #- Prepare DataLoader: source-only dataset over the 'test' split.
    # drop_last=True: a trailing partial batch is discarded.
    test_data = torch.load(opt.test_file)
    test_src_insts = test_data['test']['src']
    test_tgt_insts = test_data['test']['tgt']
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=test_data['dict']['src'],
            tgt_word2idx=test_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        drop_last=True,
        collate_fn=collate_fn)

    print('[Info] Evaluate on test set.')
    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader, mininterval=2,
                          desc=' - (Testing)', leave=False):
            all_hyp, all_scores = translator.translate_batch(
                *batch)  # structure: List[batch, seq, pos]
            for inst in all_hyp:
                f.write('[')
                for seq in inst:
                    # Keep only the top-ranked hypothesis for each sequence.
                    seq = seq[0]
                    pred_seq = ' '.join([
                        test_loader.dataset.tgt_idx2word[word] for word in seq
                    ])
                    f.write('\t' + pred_seq + '\n')
                f.write(']\n')
    print('[Info] Finished.')
# NOTE(review): this chunk is a truncated excerpt of a translate-style main()
# — the enclosing `def` and the loop body after the final `for` are outside
# this view, so the trailing `for` is intentionally left dangling.
parser.add_argument('-no_cuda', action='store_true')
opt = parser.parse_args()
opt.cuda = not opt.no_cuda

# Prepare DataLoader: load preprocessed vocab/settings, index-encode the raw
# source sentences, and build a source-only (inference) dataset.
preprocess_data = torch.load(opt.vocab)
preprocess_settings = preprocess_data['settings']
test_src_word_insts = read_instances_from_file(
    opt.src,
    preprocess_settings.max_word_seq_len,
    preprocess_settings.keep_case)
test_src_insts = convert_instance_to_idx_seq(
    test_src_word_insts, preprocess_data['dict']['src'])
test_loader = torch.utils.data.DataLoader(
    TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts),
    num_workers=2,
    batch_size=opt.batch_size,
    collate_fn=collate_fn)

translator = Translator(opt)

# Decode each batch; the per-hypothesis handling continues past this excerpt.
with open(opt.output, 'w') as f:
    for batch in tqdm(test_loader, mininterval=2,
                      desc=' - (Test)', leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
def main():
    '''Main Function'''
    # Evaluate an encoder-decoder model on the Nlp4plp test corpus: decode,
    # post-process predictions/golds, report exact-match accuracy, and save
    # per-instance gold/pred program files.
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-data_dir', required=True)
    parser.add_argument('-debug', action='store_true')
    parser.add_argument('-dir_out', default="/home/suster/Apps/out/")
    parser.add_argument(
        "--convert-consts", type=str,
        help="conv | our-map | no-our-map | no. \n/"
        "conv-> txt: -; stats: num_sym+ent_sym.\n/"
        "our-map-> txt: num_sym; stats: num_sym(from map)+ent_sym;\n/"
        "no-our-map-> txt: -; stats: num_sym(from map)+ent_sym;\n/"
        "no-> txt: -; stats: -, only ent_sym;\n/"
        "no-ent-> txt: -; stats: -, no ent_sym;\n/")
    parser.add_argument(
        "--label-type-dec", type=str, default="full-pl",
        help=
        "predicates | predicates-all | predicates-arguments-all | full-pl | full-pl-no-arg-id | full-pl-split | full-pl-split-plc | full-pl-split-stat-dyn. To use with EncDec."
    )
    parser.add_argument('-vocab', required=True)
    #parser.add_argument('-output', default='pred.txt',
    #                    help="""Path to output the predictions (each line will
    #                    be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    args = parser.parse_args()
    args.cuda = not args.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(args.vocab)
    preprocess_settings = preprocess_data['settings']
    # Sanity-check that the chosen constant-conversion mode matches the kind
    # of data directory supplied ("nums_mapped" directories vs plain ones).
    if args.convert_consts in {"conv"}:
        assert "nums_mapped" not in args.data_dir
    elif args.convert_consts in {"our-map", "no-our-map", "no", "no-ent"}:
        assert "nums_mapped" in args.data_dir
    else:
        if args.convert_consts is not None:
            raise ValueError
    test_corp = Nlp4plpCorpus(args.data_dir + "test", args.convert_consts)
    if args.debug:
        # Keep only a handful of instances for a quick run.
        test_corp.insts = test_corp.insts[:10]
    test_corp.get_labels(label_type=args.label_type_dec)
    test_corp.remove_none_labels()
    # Training set
    test_src_word_insts, test_src_id_insts = prepare_instances(test_corp.insts)
    test_tgt_word_insts, test_tgt_id_insts = prepare_instances(test_corp.insts,
                                                               label=True)
    # Source and target instance ids must line up one-to-one.
    assert test_src_id_insts == test_tgt_id_insts
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=0,
        batch_size=args.batch_size,
        collate_fn=collate_fn)

    translator = Translator(args)

    # Decode every batch; `i` walks the flat instance index in step with the
    # hypotheses so predictions can be matched to golds and num2n maps.
    i = 0
    preds = []
    golds = []
    for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)',
                      leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                # Strip the end-of-sequence token from the prediction and the
                # sentence markers from the gold.
                pred = [
                    test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq
                    if test_loader.dataset.tgt_idx2word[idx] != "</s>"
                ]
                gold = [
                    w for w in test_tgt_word_insts[i]
                    if w not in {"<s>", "</s>"}
                ]
                if args.convert_consts == "no":
                    num2n = None
                else:
                    # Recover the per-instance number-symbol map; the corpus
                    # order must match the id order from prepare_instances.
                    id = test_src_id_insts[i]
                    assert test_corp.insts[i].id == id
                    num2n = test_corp.insts[i].num2n_map
                pred = final_repl(pred, num2n)
                gold = final_repl(gold, num2n)
                preds.append(pred)
                golds.append(gold)
                i += 1
    acc = accuracy_score(golds, preds)
    print(f"Accuracy: {acc:.3f}")
    print("Saving predictions from the best model:")
    assert len(test_src_id_insts) == len(test_src_word_insts) == len(
        preds) == len(golds)
    # Timestamped output directory so repeated runs don't collide.
    f_model = f'{datetime.now().strftime("%Y%m%d_%H%M%S_%f")}'
    dir_out = f"{args.dir_out}log_w{f_model}/"
    print(f"Save preds dir: {dir_out}")
    if not os.path.exists(dir_out):
        os.makedirs(dir_out)
    # One gold (.pl_t) and one prediction (.pl_p) file per instance id.
    for (id, gold, pred) in zip(test_src_id_insts, golds, preds):
        f_name_t = os.path.basename(f"{id}.pl_t")
        f_name_p = os.path.basename(f"{id}.pl_p")
        with open(dir_out + f_name_t, "w") as f_out_t, open(
                dir_out + f_name_p, "w") as f_out_p:
            f_out_t.write(gold)
            f_out_p.write(pred)
    #with open(args.output, 'w') as f:
    #    golds
    #    preds
    #    f.write("PRED: " + pred_line + '\n')
    #    f.write("GOLD: " + gold_line + '\n')
    print('[Info] Finished.')
def main():
    '''Main Function'''
    # Predict equations for math word problems: slice a validation window
    # out of the preprocessed data (by -split/-offset), decode with beam
    # search, optionally restore real numbers, and dump results as JSON.
    parser = argparse.ArgumentParser(description='predict.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    # parser.add_argument('-src', required=True,
    #                     help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-data', required=True, help='preprocessed data file')
    parser.add_argument(
        '-original_data', default=config.FORMATTED_DATA,
        help='original data showing original text and equations')
    parser.add_argument(
        '-vocab', default=None,
        help=
        'data file for vocabulary. if not specified (default), use the one in -data'
    )
    parser.add_argument(
        '-split', type=float, default=0.8,
        help='proprotion of training data. the rest is test data.')
    parser.add_argument(
        '-offset', type=float, default=0,
        help="determin starting index of training set, for cross validation")
    parser.add_argument(
        '-output', default='pred.json',
        help=
        """Path to output the predictions (each line will be the decoded sequence"""
    )
    parser.add_argument('-beam_size', type=int, default=10, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=64, help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-reset_num', default=False, action='store_true',
                        help='replace number symbols with real numbers')
    parser.add_argument('-no_cuda', action='store_true')
    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    print(opt)

    # Prepare DataLoader
    preprocess_data = torch.load(opt.data)
    if opt.original_data is not None:
        # Index the original (un-normalized) problems by id for lookup later.
        formmated_data = json.load(open(opt.original_data))
        formmated_map = {}
        for d in formmated_data:
            formmated_map[d['id']] = d
    N = preprocess_data['settings']['n_instances']
    train_len = int(N * opt.split)
    start_idx = int(opt.offset * N)  # start location of training data
    print("Data split: {}".format(opt.split))
    print("Training starts at: {} out of {} instances".format(start_idx, N))
    # The validation window is everything outside the (circular) training
    # window; wrap around the end of the data when necessary.
    if start_idx + train_len < N:
        valid_src_insts = preprocess_data['src'][
            start_idx + train_len:] + preprocess_data['src'][:start_idx]
        valid_tgt_insts = preprocess_data['tgt'][
            start_idx + train_len:] + preprocess_data['tgt'][:start_idx]
    else:
        valid_len = N - train_len
        valid_start_idx = start_idx - valid_len
        valid_src_insts = preprocess_data['src'][valid_start_idx:start_idx]
        valid_tgt_insts = preprocess_data['tgt'][valid_start_idx:start_idx]
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=valid_src_insts),
        num_workers=2,
        batch_size=opt.batch_size)
    # collate_fn=collate_fn)
    # Use the dataset's own collate function instead of the module-level one.
    test_loader.collate_fn = test_loader.dataset.collate_fn
    tgt_insts = valid_tgt_insts
    # Forbid the decoder from emitting the unknown-word token.
    block_list = [preprocess_data['dict']['tgt'][UNK_WORD]]
    translator = Translator(opt)
    # translator = NTMTranslator(opt)
    translator.model.eval()
    output = []
    # `n` is the flat index into the validation window; (n + train_len +
    # start_idx) % N maps it back to the original instance index.
    n = 0
    for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)',
                      leave=False):
        with torch.no_grad():
            all_hyp_list, all_score_list = translator.translate_batch(
                *batch, block_list=block_list)
        for i, idx_seqs in enumerate(
                all_hyp_list[0]):  # loop over instances in batch
            scores = all_score_list[0][i]
            if translator.opt.bi:  # bidirectional
                idx_seqs_reverse = all_hyp_list[1][i]
                scores_reverse = all_score_list[1][i]
            for j, idx_seq in enumerate(idx_seqs):  # loop over n_best results
                d = {}
                question_id = preprocess_data['idx2id'][(n + train_len +
                                                         start_idx) % N]
                pred_line = ''.join(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                score = scores[j]
                if translator.opt.bi:
                    # Reverse-direction hypothesis is decoded back-to-front.
                    idx_seq_reverse = idx_seqs_reverse[j]
                    score_reverse = scores_reverse[j]
                    idx_seq_reverse.reverse()
                    pred_line_reverse = ''.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in idx_seq_reverse
                    ])
                src_idx_seq = test_loader.dataset[n]  # truth
                src_text = ' '.join([
                    test_loader.dataset.src_idx2word[idx]
                    for idx in src_idx_seq
                ])
                tgt_text = ''.join([
                    test_loader.dataset.tgt_idx2word[idx]
                    for idx in tgt_insts[n]
                ])
                if opt.reset_num:
                    # Substitute the number symbols back with the instance's
                    # actual numbers.
                    src_text = reset_numbers(
                        src_text,
                        preprocess_data['numbers'][(n + train_len + start_idx)
                                                   % N])
                    # tgt_text = reset_numbers(tgt_text, preprocess_data['numbers'][n + train_len])
                    tgt_text = ';'.join(
                        formmated_map[question_id]['equations'])
                    pred_line = reset_numbers(
                        pred_line,
                        preprocess_data['numbers'][(n + train_len + start_idx)
                                                   % N],
                        try_similar=True)
                    if translator.opt.bi:
                        pred_line_reverse = reset_numbers(
                            pred_line_reverse,
                            preprocess_data['numbers'][(n + train_len +
                                                        start_idx) % N],
                            try_similar=True)
                # print(pred_line, tgt_text)
                # print(pred_line_reverse, tgt_text, '\n')
                d['question'] = src_text
                d['ans'] = preprocess_data['ans'][(n + train_len + start_idx)
                                                  % N]
                d['id'] = question_id
                d['equation'] = tgt_text
                d['pred'] = (pred_line.replace('</s>', ''),
                             round(score.item(), 3))
                if translator.opt.bi:
                    d['pred_2'] = (pred_line_reverse.replace('</s>', ''),
                                   round(score_reverse.item(), 3))
                output.append(d)
                n += 1
    with open(opt.output, 'w') as f:
        json.dump(output, f, indent=2)
    print('[Info] Finished.')
def main():
    """Translate a source file with a trained model and write predictions.

    NOTE(review): several options are hard-coded below for debugging
    (CUDA disabled, model/src/vocab paths overridden), so the CLI flags
    are effectively ignored.  Remove the overrides for real use.
    (Original note: this model translates English to German.)
    """
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=False, help='Path to model .pt file')
    parser.add_argument('-src', required=False,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=False,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-output', default='2',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    # -vocab data/multi30k.atok.low.pt

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    # Debug overrides: force CPU and fixed file paths.
    opt.cuda = False
    opt.model = 'trained.chkpt'
    opt.src = '1'
    opt.vocab = 'multi30k.atok.low.pt'

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)

    # Debug dumps of both vocabularies.  Fixed: both files are now written
    # as UTF-8 -- the source dump previously used the platform default
    # encoding and could raise UnicodeEncodeError for non-ASCII tokens.
    tmp1 = preprocess_data['dict']['src']
    tmp2 = preprocess_data['dict']['tgt']
    with open('55', 'w', encoding='utf-8') as f:
        f.write(str(tmp1))
    with open('66', 'w', encoding='utf-8') as f:
        f.write(str(tmp2))

    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    # The dataset only needs __len__/__getitem__; collate_fn batches the
    # returned instances (source side only -- no targets at test time).
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    # Predictions may contain non-ASCII target tokens, so write UTF-8.
    with open(opt.output, 'w', encoding='utf-8') as f:
        for batch in test_loader:
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:  # one hypothesis list per instance
                for idx_seq in idx_seqs:  # n_best sequences per instance
                    print(idx_seq)
                    # Map token ids back to text.
                    pred_line = ' '.join(
                        [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                    f.write(pred_line + '\n')
    print('[Info] Finished.')
def main():
    """Decode one hard-coded source line and plot attention heat maps.

    Saves per-layer/per-head attention images prefixed with ``opt.output``:
    encoder self-attention, decoder self-attention, and decoder source
    (encoder-decoder) attention.
    """
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-vocab', required=True, help='Path to vocabulary file')
    parser.add_argument('-output', help="""Path to output the predictions""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Single hard-coded input line (tokens separated by spaces).
    src_line = "Binary files a / build / linux / jre . tgz and b / build / linux / jre . tgz differ <nl>"

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances(
        src_line,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])

    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts),
        num_workers=2,
        batch_size=1,
        collate_fn=collate_fn)

    translator = Translator(opt)

    for batch in tqdm(test_loader, mininterval=1, desc=' - (Test)', leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:
            for idx_seq in idx_seqs:
                # Drop the trailing end-of-sentence token before printing.
                pred_line = ' '.join(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq[:-1]])
                print(pred_line)

    # pred_line holds the last decoded hypothesis after the loop.
    sent = src_line.split()
    tgt_sent = pred_line.split()

    for layer in range(0, 2):
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        print("Encoder Layer", layer + 1)
        for h in range(4):
            print(translator.model.encoder.layer_stack[layer].slf_attn.attn.data.cpu().size())
            draw(translator.model.encoder.layer_stack[layer].slf_attn.attn[h, :, :].data.cpu(),
                 sent, sent if h == 0 else [], ax=axs[h])
        plt.savefig(opt.output + "Encoder Layer %d.png" % layer)
        plt.close(fig)  # free the figure; repeated subplots otherwise accumulate

    for layer in range(0, 2):
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        print("Decoder Self Layer", layer + 1)
        for h in range(4):
            print(translator.model.decoder.layer_stack[layer].slf_attn.attn.data.cpu().size())
            draw(translator.model.decoder.layer_stack[layer].slf_attn.attn[:, :, h]
                 .data[:len(tgt_sent), :len(tgt_sent)].cpu(),
                 tgt_sent, tgt_sent if h == 0 else [], ax=axs[h])
        plt.savefig(opt.output + "Decoder Self Layer %d.png" % layer)
        plt.close(fig)

        print("Decoder Src Layer", layer + 1)
        fig, axs = plt.subplots(1, 4, figsize=(20, 10))
        for h in range(4):
            # Fixed: this plot previously reused slf_attn, so it re-plotted
            # decoder self-attention under a "Src" label.  enc_attn is the
            # decoder's attention over the encoder outputs.
            # NOTE(review): the [:len(sent), :len(tgt_sent)] slice is kept
            # as-is; confirm row/column order (query=target, key=source)
            # for this attention matrix.
            draw(translator.model.decoder.layer_stack[layer].enc_attn.attn[:, :, h]
                 .data[:len(sent), :len(tgt_sent)].cpu(),
                 tgt_sent, sent if h == 0 else [], ax=axs[h])
        plt.savefig(opt.output + "Decoder Src Layer %d.png" % layer)
        plt.close(fig)
    print('[Info] Finished.')
def main():
    """Translate opt.src with a trained model, write detokenised
    predictions to opt.output, and print corpus BLEU against opt.target.
    """
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument(
        '-src',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument(
        '-target',
        required=True,
        help='Target sequence to decode (one line per sequence)')
    parser.add_argument(
        '-vocab',
        required=True,
        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')
    # Pruning options -- presumably consumed by Translator/the model, not
    # used directly in this function; confirm against Translator.__init__.
    parser.add_argument('-prune', action='store_true')
    parser.add_argument('-prune_alpha', type=float, default=0.1)
    parser.add_argument('-load_mask', type=str, default=None)

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    # References are tokenised with the same length/case settings used at
    # preprocessing time, so they line up with the decoded sentences.
    refs = read_instances_from_file(opt.target,
                                    preprocess_settings.max_word_seq_len,
                                    preprocess_settings.keep_case)
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_loader = torch.utils.data.DataLoader(TranslationDataset(
        src_word2idx=preprocess_data['dict']['src'],
        tgt_word2idx=preprocess_data['dict']['tgt'],
        src_insts=test_src_insts,
    ),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    preds = []       # raw token lists, one per decoded sentence (fed to BLEU)
    preds_text = []  # cleaned, detokenised text written to opt.output
    for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)',
                      leave=False):
        all_hyp, all_scores = translator.translate_batch(*batch)
        for idx_seqs in all_hyp:      # instances in the batch
            for idx_seq in idx_seqs:  # n_best hypotheses per instance
                sent = ' '.join(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
                # Cut at the end-of-sentence marker.
                sent = sent.split("</s>")[0].strip()
                # '▁' is presumably the SentencePiece word-boundary marker;
                # replacing it detokenises the subword output -- confirm.
                sent = sent.replace("▁", " ")
                preds_text.append(sent.strip())
                # NOTE(review): preds keeps the raw tokens (including
                # '</s>' and subword pieces) while preds_text is cleaned;
                # BLEU below is computed on the raw tokens -- confirm this
                # asymmetry is intended.
                preds.append(
                    [test_loader.dataset.tgt_idx2word[idx] for idx in idx_seq])
    # NOTE(review): output file is opened without an explicit encoding.
    with open(opt.output, 'w') as f:
        f.write('\n'.join(preds_text))

    from evaluator import BLEUEvaluator
    scorer = BLEUEvaluator()
    # Guard against a length mismatch between hypotheses and references.
    length = min(len(preds), len(refs))
    score = scorer.evaluate(refs[:length], preds[:length])
    print(score)
# NOTE(review): fragment of a larger setup routine -- the enclosing
# definition (providing min_freq, max_len, model_type, logger and the
# imports) is outside this view.

# Cached artefacts for a fixed 52,246-pair fra->eng corpus; rebuild them
# all if any one is missing.
in_lang_path = f"cache/in-fra-52,246-{min_freq}.pkl"
out_lang_path = f"cache/out-eng-52,246-{min_freq}.pkl"
pair_path = f"cache/fra2eng-52,246.pkl"
exist_all = all(
    os.path.exists(path) for path in [in_lang_path, out_lang_path, pair_path])
if not exist_all:
    data_prepare.prepare(max_len, min_freq)

input_lang = Lang.load_from_file("fra", in_lang_path)
output_lang = Lang.load_from_file("eng", out_lang_path)
# NOTE(review): the file handle from open() is never closed explicitly;
# prefer `with open(pair_path, "rb") as f: pairs = pickle.load(f)`.
pairs = pickle.load(open(pair_path, "rb"))
logger.info("\tinput_lang.n_words = {}".format(input_lang.n_words))
logger.info("\toutput_lang.n_words = {}".format(output_lang.n_words))
logger.info("\t# of pairs = {}".format(len(pairs)))
dset = TranslationDataset(input_lang, output_lang, pairs, max_len)
logger.info(random.choice(pairs))  # log one random example pair

# split dset by valid indices
N_pairs = len(pairs)
val_indices_path = f"cache/valid_indices-{N_pairs}.npy"
if not os.path.exists(val_indices_path):
    # Persist a 10% validation split so it is stable across runs.
    data_prepare.gen_valid_indices(N_pairs, 0.1, val_indices_path)
valid_indices = np.load(val_indices_path)
# NOTE(review): the order of train_indices relies on set iteration order
# rather than being sorted -- confirm downstream shuffling makes the
# ordering irrelevant.
train_indices = list(set(range(len(dset))) - set(valid_indices))
train_dset = Subset(dset, train_indices)
valid_dset = Subset(dset, valid_indices)

# loader
# For the RNN model batches are sorted by source length (src_sort);
# otherwise the default collate is used unchanged.
collate_fn = src_sort if model_type == 'rnn' else torch.utils.data.dataloader.default_collate
logger.info("Load loader")
def main():
    """Translate a test set, compute word-level accuracy against the
    references, and dump readable (src, tgt, pred) triples.

    The accuracy line is appended to ``opt.log``; the triples are written
    to ``opt.output``.
    """
    parser = argparse.ArgumentParser(description='translate.py')
    parser.add_argument('-model', required=True, help='Path to model .pt file')
    parser.add_argument('-src', required=True,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-tgt', required=True,
                        help='Target sequence to decode (one line per sequence)')
    parser.add_argument('-vocab', required=True,
                        help='Source sequence to decode (one line per sequence)')
    parser.add_argument('-log', default='translate_log.txt',
                        help="""Path to log the translation(test_inference) loss""")
    parser.add_argument('-output', default='pred.txt',
                        help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=2, help='Batch size')
    parser.add_argument('-n_best', type=int, default=1,
                        help="""If verbose is set, will output the n_best
                        decoded sentences""")
    parser.add_argument('-no_cuda', action='store_true')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Prepare DataLoader (both sides, since we score against references).
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_tgt_word_insts = read_instances_from_file(
        opt.tgt,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(
        test_src_word_insts, preprocess_data['dict']['src'])
    test_tgt_insts = convert_instance_to_idx_seq(
        test_tgt_word_insts, preprocess_data['dict']['tgt'])
    test_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=test_src_insts,
            tgt_insts=test_tgt_insts),
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=paired_collate_fn)

    translator = Translator(opt)

    n_word_total = 0
    n_word_correct = 0
    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)',
                          leave=False):
            all_hyp, all_scores = translator.translate_batch(batch[0], batch[1])

            src_seqs = batch[0]
            tgt_seqs = batch[2]
            gold = tgt_seqs[:, 1:]  # drop the leading BOS from the reference
            max_len = gold.shape[1]

            # Pad (or truncate) the best hypothesis of every instance to the
            # gold length so predictions compare position-by-position.
            pred_seq = []
            for item in all_hyp:
                curr_item = item[0]  # best (first) hypothesis only
                curr_len = len(curr_item)
                if curr_len < max_len:
                    curr_item.extend([0] * (max_len - curr_len))
                else:
                    # TODO: why can the hypothesis exceed the gold length?
                    curr_item = curr_item[:max_len]
                pred_seq.append(curr_item)
            pred_seq = torch.LongTensor(np.array(pred_seq))
            # Fixed: flatten with view(-1) instead of
            # view(opt.batch_size * max_len) -- the final batch may hold
            # fewer than opt.batch_size instances, which made the old
            # reshape raise a RuntimeError.
            pred_seq = pred_seq.view(-1)

            n_correct = cal_performance(pred_seq, gold)
            non_pad_mask = gold.ne(Constants.PAD)
            n_word = non_pad_mask.sum().item()
            n_word_total += n_word
            n_word_correct += n_correct

            # Write readable src/tgt/pred triples.  The inner loop variable
            # is named hyp_seq so it no longer shadows the pred_seq tensor.
            for count, pred_seqs in enumerate(all_hyp):
                src_seq = src_seqs[count]
                tgt_seq = tgt_seqs[count]
                for hyp_seq in pred_seqs:
                    src_line = ' '.join([
                        test_loader.dataset.src_idx2word[idx]
                        for idx in src_seq.data.cpu().numpy()
                    ])
                    tgt_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in tgt_seq.data.cpu().numpy()
                    ])
                    pred_line = ' '.join([
                        test_loader.dataset.tgt_idx2word[idx]
                        for idx in hyp_seq
                    ])
                    f.write(
                        "\n ---------------------------------------------------------------------------------------------------------------------------------------------- \n"
                    )
                    f.write("\n [src] " + src_line + '\n')
                    f.write("\n [tgt] " + tgt_line + '\n')
                    f.write("\n [pred] " + pred_line + '\n')

    # Word-level accuracy over the whole test set (guard the empty case).
    accuracy = n_word_correct / n_word_total if n_word_total else 0.0
    accr_log = "accuracy: {} |".format(accuracy)
    print(accr_log)
    with open(opt.log, 'a') as log_tf:
        log_tf.write(accr_log + '\n')
    print('[Info] Finished.')
def main():
    """Command-line entry point: beam-search decode a source file and
    write one predicted sentence per line to the output file."""
    arg_parser = argparse.ArgumentParser(description='translate.py')
    arg_parser.add_argument('-model', required=True, help='Path to model .pt file')
    arg_parser.add_argument(
        '-src', required=True,
        help='Source sequence to decode (one line per sequence)')
    arg_parser.add_argument(
        '-vocab', required=True,
        help='Source sequence to decode (one line per sequence)')
    arg_parser.add_argument('-output', default='pred.txt',
                            help="""Path to output the predictions (each line will
                        be the decoded sequence""")
    arg_parser.add_argument('-beam_size', type=int, default=5, help='Beam size')
    arg_parser.add_argument('-batch_size', type=int, default=30, help='Batch size')
    arg_parser.add_argument('-n_best', type=int, default=1,
                            help="""If verbose is set, will output the n_best
                        decoded sentences""")
    arg_parser.add_argument('-no_cuda', action='store_true')

    opt = arg_parser.parse_args()
    opt.cuda = not opt.no_cuda

    # Load the vocabulary/settings bundle produced by preprocessing.
    preprocess_data = torch.load(opt.vocab)
    preprocess_settings = preprocess_data['settings']
    src_vocab = preprocess_data['dict']['src']
    tgt_vocab = preprocess_data['dict']['tgt']

    # Tokenise the raw source file, then map tokens to vocabulary indices.
    test_src_word_insts = read_instances_from_file(
        opt.src,
        preprocess_settings.max_word_seq_len,
        preprocess_settings.keep_case)
    test_src_insts = convert_instance_to_idx_seq(test_src_word_insts, src_vocab)

    test_dataset = TranslationDataset(
        src_word2idx=src_vocab,
        tgt_word2idx=tgt_vocab,
        src_insts=test_src_insts)
    test_loader = torch.utils.data.DataLoader(
        test_dataset,
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=collate_fn)

    translator = Translator(opt)

    idx2word = test_loader.dataset.tgt_idx2word
    with open(opt.output, 'w') as f:
        for batch in tqdm(test_loader, mininterval=2, desc=' - (Test)',
                          leave=False):
            # all_hyp holds one list of n_best index sequences per instance.
            all_hyp, all_scores = translator.translate_batch(*batch)
            for idx_seqs in all_hyp:
                for idx_seq in idx_seqs:
                    words = [idx2word[idx] for idx in idx_seq]
                    f.write(' '.join(words) + '\n')
    print('[Info] Finished.')
def prepare_dataloaders(data, opt):
    """Split `data` into train/validation parts and wrap them in DataLoaders.

    A contiguous window of ``int(N * opt.split)`` instances starting at
    ``int(opt.offset * N)`` (wrapping around the end of the corpus) forms
    the training set; the remaining instances form the validation set, so
    different ``opt.offset`` values implement cross-validation folds.

    Args:
        data: preprocessed corpus dict with 'settings', 'dict', 'src',
            'tgt' and 'tgt_nums' entries.
        opt: options providing ``split``, ``offset``, ``batch_size``, ``bi``.

    Returns:
        (train_loader, valid_loader) tuple.
    """
    # ========= Preparing DataLoader =========#
    N = data['settings']['n_instances']
    train_len = int(N * opt.split)
    start_idx = int(opt.offset * N)
    print("Data split: {}".format(opt.split))
    print("Training starts at: {} out of {} instances".format(start_idx, N))

    def _split(insts):
        # One wrap-around train/valid split, shared by src/tgt/tgt_nums so
        # the three lists stay aligned (previously written out three times).
        if start_idx + train_len < N:
            train = insts[start_idx: start_idx + train_len]
            valid = insts[start_idx + train_len:] + insts[:start_idx]
        else:
            # The training window wraps past the end of the corpus.
            valid_start_idx = start_idx - (N - train_len)
            train = insts[start_idx:] + insts[:valid_start_idx]
            valid = insts[valid_start_idx: start_idx]
        return train, valid

    train_src_insts, valid_src_insts = _split(data['src'])
    train_tgt_insts, valid_tgt_insts = _split(data['tgt'])
    train_tgt_nums, valid_tgt_nums = _split(data['tgt_nums'])

    train_dataset = TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=train_src_insts,
        tgt_insts=train_tgt_insts,
        tgt_nums=train_tgt_nums,
        permute_tgt=False)
    valid_dataset = TranslationDataset(
        src_word2idx=data['dict']['src'],
        tgt_word2idx=data['dict']['tgt'],
        src_insts=valid_src_insts,
        tgt_insts=valid_tgt_insts,
        tgt_nums=valid_tgt_nums,
        permute_tgt=False)

    def _collate(dataset):
        # Bidirectional training needs forward+reverse target batches.
        if opt.bi:
            return dataset.bidirectional_collate_fn
        return dataset.paired_collate_fn

    # Fixed: collate_fn is passed to the DataLoader constructor instead of
    # being assigned to loader.collate_fn after construction -- post-hoc
    # mutation of a DataLoader is not part of its public API and is
    # fragile across PyTorch versions.
    train_loader = torch.utils.data.DataLoader(
        train_dataset,
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=_collate(train_dataset),
        shuffle=True)  # reshuffle training batches every epoch
    valid_loader = torch.utils.data.DataLoader(
        valid_dataset,
        num_workers=2,
        batch_size=opt.batch_size,
        collate_fn=_collate(valid_dataset))
    return train_loader, valid_loader
def main():
    """Reinforcement fine-tuning entry point.

    Loads a pretrained Translator checkpoint, rebuilds the training split
    used by the supervised phase, then runs policy-gradient training: each
    batch is beam-search decoded and ``train_batch`` turns the hypotheses
    and scores into a loss.  A checkpoint (including the GCL memory) is
    saved after every epoch.
    """
    parser = argparse.ArgumentParser(description='reinforcement training')
    parser.add_argument('-model', required=True,
                        help='Path to pretrained model .pt file')
    parser.add_argument('-data', required=True, help='preprocessed data file')
    parser.add_argument(
        '-original_data',
        default=config.FORMATTED_DATA,
        help='original data showing original text and equations')
    parser.add_argument(
        '-vocab',
        default=None,
        help=
        'data file for vocabulary. if not specified (default), use the one in -data'
    )
    parser.add_argument(
        '-split',
        type=float,
        default=0.8,
        help='proprotion of training data. the rest is test data.')
    parser.add_argument(
        '-offset',
        type=float,
        default=0,
        help="determin starting index of training set, for cross validation")
    parser.add_argument('-save_model', default=None,
                        help="model destination path")
    parser.add_argument('-beam_size', type=int, default=8, help='Beam size')
    parser.add_argument('-batch_size', type=int, default=4, help='Batch size')
    parser.add_argument(
        '-n_best', type=int, default=8,
        help="If verbose is set, will output the n_best decoded sentences")
    parser.add_argument('-no_cuda', action='store_true')
    parser.add_argument('-epochs', type=int, default=100)
    parser.add_argument('-teacher_ratio', type=float, default=0.,
                        help="probability to allow teacher forcing")
    parser.add_argument('-permute', action='store_true',
                        help="permute equations for training")

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.reset_num = True  # use numbers (not symbols) in output
    print(opt)

    # Prepare DataLoader
    preprocess_data = torch.load(opt.data)

    # Fixed: formatted_map is always defined.  It used to be created only
    # inside the `if` below, so the train_batch() call raised NameError
    # whenever -original_data was None.
    formatted_map = {}
    if opt.original_data is not None:
        # Map question id -> original record (text / equations).
        with open(opt.original_data) as f:
            formatted_data = json.load(f)
        formatted_map = {d['id']: d for d in formatted_data}

    N = preprocess_data['settings']['n_instances']
    train_len = int(N * opt.split)
    start_idx = int(opt.offset * N)
    print("Data split: {}".format(opt.split))
    print("Training starts at: {} out of {} instances".format(start_idx, N))

    # Rebuild the same (possibly wrapping) training window as the
    # supervised run so instance indices line up with preprocess_data.
    if start_idx + train_len < N:
        train_src_insts = preprocess_data['src'][start_idx:start_idx + train_len]
        train_tgt_insts = preprocess_data['tgt'][start_idx:start_idx + train_len]
        train_tgt_nums = preprocess_data['tgt_nums'][start_idx:start_idx + train_len]
    else:
        valid_len = N - train_len
        valid_start_idx = start_idx - valid_len
        train_src_insts = preprocess_data['src'][start_idx:] + \
            preprocess_data['src'][:valid_start_idx]
        train_tgt_insts = preprocess_data['tgt'][start_idx:] + \
            preprocess_data['tgt'][:valid_start_idx]
        train_tgt_nums = preprocess_data['tgt_nums'][start_idx:] + \
            preprocess_data['tgt_nums'][:valid_start_idx]

    data_loader = torch.utils.data.DataLoader(
        TranslationDataset(
            src_word2idx=preprocess_data['dict']['src'],
            tgt_word2idx=preprocess_data['dict']['tgt'],
            src_insts=train_src_insts,
            tgt_insts=train_tgt_insts,
            tgt_nums=train_tgt_nums,
            permute_tgt=False),
        num_workers=1,
        batch_size=opt.batch_size)
    # NOTE(review): assigned after construction; prefer passing collate_fn
    # to the constructor (requires building the dataset first).
    data_loader.collate_fn = data_loader.dataset.bidirectional_collate_fn

    translator = Translator(opt)
    # Remember the checkpoint's decode length: it is shortened during
    # training and restored before each save below.
    original_max_token_seq_len = translator.model_opt.max_token_seq_len
    translator.model.train()

    # Teacher-forcing optimizer (kept for the currently-disabled
    # teacher_train_batch fallback).
    optimizer_teacher = Scheduler(optim.Adam(filter(
        lambda x: x.requires_grad, translator.model.parameters()),
                                             betas=(0.9, 0.98),
                                             eps=1e-09),
                                  alpha=1e-6)
    # Reinforcement-training optimizer (smaller learning-rate scale).
    optimizer_reinforce = Scheduler(optim.Adam(filter(
        lambda x: x.requires_grad, translator.model.parameters()),
                                               betas=(0.9, 0.98),
                                               eps=1e-09),
                                    alpha=5e-7)

    for epoch in range(opt.epochs):
        start = time.time()
        instance_idx = start_idx
        n_correct = 0
        total_loss = 0
        # One LR-schedule step per epoch (not per batch).
        optimizer_reinforce.n_current_steps += 1
        # Reset the GCL memory at the start of each epoch.
        translator.model.encoder.gcl.init_sequence(1)
        translator.model.encoder.memory_ready = False

        for batch in tqdm(data_loader, mininterval=2, desc=' - (Train)',
                          leave=True):
            # batch: (*src_insts, *tgt_insts, *tgt_nums_insts)
            translator.model_opt.max_token_seq_len = 32  # keep decoding short during training
            all_hyp_list, all_score_list = translator.translate_batch(
                batch[0], batch[1], block_list=[])
            # Reinforcement step: turn hypotheses + scores into a loss.
            batch_loss, batch_n_correct = train_batch(
                all_hyp_list, all_score_list, translator, data_loader,
                preprocess_data, formatted_map, instance_idx, opt)
            optimizer_reinforce.zero_grad()
            batch_loss.backward()
            optimizer_reinforce.step_and_update_lr()

            total_loss += batch_loss.item()
            n_correct += batch_n_correct
            # Advance to the first instance of the next batch (wraps at N).
            instance_idx = (instance_idx + opt.batch_size) % N

        # End of epoch: report perplexity/accuracy and save a checkpoint.
        train_acc = n_correct / train_len
        total_loss = total_loss * opt.batch_size / train_len
        sys.stdout.write('\n - (Training) ppl: {ppl: 8.5f}, accuracy: {accu:3.3f} %, '
                         'elapse: {elapse:3.3f} min\n'.format(
                             ppl=math.exp(min(total_loss, 100)),
                             accu=100 * train_acc,
                             elapse=(time.time() - start) / 60))
        sys.stdout.flush()

        model_state_dict = translator.model.state_dict()
        translator.model_opt.max_token_seq_len = original_max_token_seq_len
        checkpoint = {
            'model': model_state_dict,
            'memory': translator.model.encoder.gcl.memory,
            'settings': translator.model_opt,
            'epoch': epoch
        }
        model_name = opt.save_model + '.chkpt'
        torch.save(checkpoint, model_name)