def query_instances(args, unlabeled_dataset, oracle, active_func="random", labeled_dataset=None):
    """Rank and print unlabeled sentences for active-learning selection.

    For each sentence (in ranked order) prints the source (S:), the model
    hypothesis (H:, where available), the oracle target (T:), the ranking
    value (V:) and an index record (I:).  Ranking strategy is chosen by
    ``active_func``:

    - "random":   shuffled order
    - "longest"/"shortest": by token count (BPE removed)
    - "lc"/"margin"/"te"/"tte": by model uncertainty score from get_scores()
      (lc = least confident, te = token entropy, tte = total token entropy)
    - "dden":     decaying-density score over token log-frequencies,
                  penalizing tokens already seen in the labeled set
                  (count_l) and in the already-selected batch (count_batch)

    args is an argparse-style namespace; this reads args.no_cuda,
    args.checkpoint, args.batch_size, args.max_batch_size,
    args.tokens_per_batch, args.input, args.reference and
    args.previous_num_sents, and writes args.use_cuda.
    labeled_dataset is only required for "dden": a pair
    [source_sentences, target_sentences].
    """
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte", "dden"
    ]
    # lengths represents number of tokens, so BPE should be removed
    lengths = np.array([
        len(remove_special_tok(remove_bpe(s)).split())
        for s in unlabeled_dataset
    ])
    # Preparations before querying instances
    # Reloading network parameters
    args.use_cuda = (args.no_cuda == False) and torch.cuda.is_available()
    net, _ = model.get()
    assert os.path.exists(args.checkpoint)
    net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)
    if args.use_cuda:
        net = net.cuda()
    # Initialize inference dataset (Unlabeled dataset)
    infer_dataset = Dataset(unlabeled_dataset, src_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch
    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))
    # Start ranking unlabeled dataset
    indices = np.arange(len(unlabeled_dataset))
    if active_func == "random":
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        random.shuffle(result)
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            # NOTE(review): idx is a dataset index, but result was shuffled,
            # so result[idx] is not the entry for unlabeled_dataset[idx] —
            # H:/V: lines likely pair the wrong hypothesis/score with S:/T:.
            # Compare the lc/margin branch below, which indexes result
            # positionally.  Verify intent before relying on this output.
            print("H:", result[idx][2])
            print("T:", oracle[idx])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference,
                  idx + args.previous_num_sents)
    elif active_func == "longest":
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        # Rebuild result as (token_count, dataset_index, hypothesis).
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        # Negated key => descending by length (longest first).
        result = sorted(result, key=lambda item: -item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            # NOTE(review): same dataset-index-vs-position mismatch as the
            # "random" branch — result[idx] does not correspond to idx.
            print("H:", result[idx][2])
            print("T:", oracle[idx])
            print("V:", -result[idx][0])
            print("I:", args.input, args.reference,
                  idx + args.previous_num_sents)
    elif active_func == "shortest":
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        # Ascending by length (shortest first).
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            # NOTE(review): same dataset-index-vs-position mismatch as above.
            print("H:", result[idx][2])
            print("T:", oracle[idx])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference,
                  idx + args.previous_num_sents)
        # Re-sort by the precomputed lengths (ties may order differently
        # than the sort on result above).
        indices = indices[np.argsort(lengths[indices])]
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        # Ascending score: least confident / highest-entropy first,
        # presumably — depends on get_scores() sign convention; verify.
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        # This branch indexes result positionally, so S:/H:/T:/V: lines
        # are consistently paired (unlike the branches above).
        for idx in range(len(result)):
            print("S:", unlabeled_dataset[result[idx][1]])
            print("H:", result[idx][2])
            print("T:", oracle[result[idx][1]])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference,
                  result[idx][1] + args.previous_num_sents)
    elif active_func == "dden":
        # Punctuation tokens are excluded from all density statistics.
        punc = [
            ".", ",", "?", "!", "'", "<", ">", ":", ";", "(", ")", "{", "}",
            "[", "]", "-", "..", "...", "...."
        ]
        lamb1 = 1  # decay rate for counts from the labeled set
        lamb2 = 1  # decay rate for counts within the selected batch
        p_u = {}   # token -> normalized log-frequency over unlabeled data
        unlabeled_dataset_without_bpe = []
        labeled_dataset_without_bpe = [[], []]
        for s in unlabeled_dataset:
            unlabeled_dataset_without_bpe.append(
                remove_special_tok(remove_bpe(s)))
        for s in labeled_dataset[0]:
            labeled_dataset_without_bpe[0].append(
                remove_special_tok(remove_bpe(s)))
        for s in labeled_dataset[1]:
            labeled_dataset_without_bpe[1].append(
                remove_special_tok(remove_bpe(s)))
        # Pass 1: raw token counts over the unlabeled pool.
        for s in unlabeled_dataset_without_bpe:
            sentence = s.split()
            for token in sentence:
                if token not in punc:
                    if token in p_u.keys():
                        p_u[token] += 1
                    else:
                        p_u[token] = 1
        # Pass 2: log-damp the counts, then normalize to sum to 1.
        total_dden = 0
        for token in p_u.keys():
            p_u[token] = math.log(p_u[token] + 1)
            total_dden += p_u[token]
        for token in p_u.keys():
            p_u[token] /= total_dden
        # Token counts over the labeled source side only.
        count_l = {}
        for s in labeled_dataset_without_bpe[0]:
            sentence = s.split()
            for token in sentence:
                if token not in punc:
                    if token in count_l.keys():
                        count_l[token] += 1
                    else:
                        count_l[token] = 1
        # Initial per-sentence density: mean of p_u, exponentially decayed
        # for tokens already present in the labeled set.
        dden = []
        for s in unlabeled_dataset_without_bpe:
            sentence = s.split()
            len_for_sentence = 0
            sum_for_sentence = 0
            for token in sentence:
                if token not in punc:
                    if token in count_l.keys():
                        sum_for_sentence += p_u[token] * math.exp(
                            -lamb1 * count_l[token])
                    else:
                        sum_for_sentence += p_u[token]
                    len_for_sentence += 1
            if len_for_sentence != 0:
                sum_for_sentence /= len_for_sentence
            dden.append(sum_for_sentence)
        # Greedy re-scoring: walk sentences in descending initial density
        # and decay tokens already taken into the batch (count_batch).
        unlabeled_with_index = []
        for i in range((len(unlabeled_dataset))):
            unlabeled_with_index.append((dden[i], i))
        unlabeled_with_index.sort(key=lambda x: x[0], reverse=True)
        count_batch = {}
        dden_new = []
        for _, i in unlabeled_with_index:
            sentence = unlabeled_dataset_without_bpe[i].split()
            len_for_sentence = 0
            sum_for_sentence = 0
            for token in sentence:
                if token not in punc:
                    p_tmp = p_u[token]
                    if token in count_batch.keys():
                        # NOTE(review): p_tmp is zeroed, so the following
                        # multiply is a no-op (0 * exp(...) == 0) — tokens
                        # already in the batch contribute nothing.  If the
                        # intent was a soft decay, `p_tmp = 0` should
                        # probably be removed; confirm against the paper /
                        # original formatting (this file's layout was
                        # mangled, so the exact nesting is reconstructed).
                        p_tmp = 0
                        p_tmp *= math.exp(-lamb2 * count_batch[token])
                    if token in count_l.keys():
                        p_tmp *= math.exp(-lamb1 * count_l[token])
                    sum_for_sentence += p_tmp
                    len_for_sentence += 1
            # Account this sentence's tokens into the running batch counts.
            for token in sentence:
                if token not in punc:
                    if token in count_batch.keys():
                        count_batch[token] += 1
                    else:
                        count_batch[token] = 1
            if len_for_sentence != 0:
                sum_for_sentence /= len_for_sentence
            dden_new.append((sum_for_sentence, i))
        # Restore dataset order, then rank descending by the new density.
        dden_new.sort(key=lambda x: x[1])
        dden_sort = []
        for dden_num, _ in dden_new:
            dden_sort.append(dden_num)
        ddens = np.array(dden_sort)
        indices = indices[np.argsort(-ddens)]
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            print("T:", oracle[idx])
            print("V:", -ddens[idx])
            print("I:", args.input, args.reference, idx)
def main():
    """Parse CLI options, build a Transformer, and run training.

    Usage:
        python train.py -data_pkl m30k_deen_shr.pkl -log m30k_deen_shr
            -embs_share_weight -proj_share_weight -label_smoothing
            -save_model trained -b 256 -warmup 128000

    Raises:
        SystemExit: when neither -log nor -save_model is given.
        RuntimeError: when no dataset source (-train_path/-val_path or
            -data_pkl) is provided.
    """
    parser = argparse.ArgumentParser()
    # all-in-1 data pickle or bpe field
    parser.add_argument('-data_pkl', default='../data/data_transformer.bin')
    parser.add_argument('-train_path', default=None)  # bpe encoded data
    parser.add_argument('-val_path', default=None)    # bpe encoded data

    parser.add_argument('-epoch', type=int, default=10)
    parser.add_argument('-b', '--batch_size', type=int, default=2048)  # 2048 -- try 512

    parser.add_argument('-d_model', type=int, default=512)
    parser.add_argument('-d_inner_hid', type=int, default=2048)
    parser.add_argument('-d_k', type=int, default=64)
    parser.add_argument('-d_v', type=int, default=64)

    parser.add_argument('-n_head', type=int, default=8)
    parser.add_argument('-n_layers', type=int, default=6)
    parser.add_argument('-warmup', '--n_warmup_steps', type=int, default=5000)  # 4000
    parser.add_argument('-dropout', type=float, default=0.1)
    parser.add_argument('-embs_share_weight', action='store_true', default=True)
    parser.add_argument('-proj_share_weight', action='store_true', default=True)

    parser.add_argument('-log', default=None)
    parser.add_argument('-save_model', default='../saved/t2t_model.tar')
    parser.add_argument('-save_mode', type=str, choices=['all', 'best'],
                        default='best')

    parser.add_argument('-no_cuda', action='store_true', default=True)
    parser.add_argument('-label_smoothing', action='store_true')
    parser.add_argument('-print_to_screen', action='store_true',
                        help='print some values to screen.')
    parser.add_argument('-load_saved', action='store_true',
                        help='use specific saved model file.')
    parser.add_argument('-vocab_file', help='path to separate vocab file.',
                        default='../data/data_vocab.bin')

    opt = parser.parse_args()
    opt.cuda = not opt.no_cuda
    opt.d_word_vec = opt.d_model

    if not opt.log and not opt.save_model:
        # Nothing would be persisted; abort early.
        # (The original code had an unreachable bare `raise` after exit().)
        print('No experiment result will be saved.')
        raise SystemExit

    if opt.batch_size < 2048 and opt.n_warmup_steps <= 4000:
        print('[Warning] The warmup steps may be not enough.\n'
              '(sz_b, warmup) = (2048, 4000) is the official setting.\n'
              'Using smaller batch w/o longer warmup may cause '
              'the warmup stage ends with only little data trained.')

    device = torch.device('cuda' if opt.cuda else 'cpu')
    print(device, 'device')

    # ========= Loading Dataset =========#
    if all((opt.train_path, opt.val_path)):
        training_data, validation_data = \
            prepare_dataloaders_from_bpe_files(opt, device)
    elif opt.data_pkl:
        training_data, validation_data = prepare_dataloaders(opt, device)
    else:
        # Was a bare `raise` (RuntimeError: no active exception); make the
        # failure explicit while keeping the exception type.
        raise RuntimeError('No dataset source: provide -train_path/-val_path '
                           'or -data_pkl.')

    print(opt)

    transformer = Transformer(
        opt.src_vocab_size,
        opt.trg_vocab_size,
        src_pad_idx=opt.src_pad_idx,
        trg_pad_idx=opt.trg_pad_idx,
        trg_emb_prj_weight_sharing=opt.proj_share_weight,
        emb_src_trg_weight_sharing=opt.embs_share_weight,
        d_k=opt.d_k,
        d_v=opt.d_v,
        d_model=opt.d_model,
        d_word_vec=opt.d_word_vec,
        d_inner=opt.d_inner_hid,
        n_layers=opt.n_layers,
        n_head=opt.n_head,
        dropout=opt.dropout).to(device)

    print(opt.save_model, 'name')
    # Optionally resume from a previously saved checkpoint.
    if opt.load_saved and os.path.isfile(opt.save_model + '.chkpt'):
        opt.model = opt.save_model + '.chkpt'
        transformer = translate.load_model(opt, device)

    optimizer = ScheduledOptim(
        optim.Adam(transformer.parameters(), betas=(0.9, 0.98), eps=1e-09),
        2.0, opt.d_model, opt.n_warmup_steps)

    train(transformer, training_data, validation_data, optimizer, device, opt)
def query_instances(args, unlabeled_dataset, oracle, active_func="random"):
    """Rank and print unlabeled sentences for active-learning selection.

    Simplified variant of query_instances without the "dden" strategy and
    without the args.previous_num_sents offset in the printed index record.
    For each sentence prints source (S:), model hypothesis (H:), oracle
    target (T:), ranking value (V:) and index record (I:).

    args is an argparse-style namespace; this reads args.no_cuda,
    args.checkpoint, args.batch_size, args.max_batch_size,
    args.tokens_per_batch, args.input and args.reference, and writes
    args.use_cuda.
    """
    # lc stands for least confident
    # te stands for token entropy
    # tte stands for total token entropy
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte"
    ]
    # lengths represents number of tokens, so BPE should be removed
    lengths = np.array([
        len(remove_special_tok(remove_bpe(s)).split())
        for s in unlabeled_dataset
    ])
    # Preparations before querying instances
    # Reloading network parameters
    args.use_cuda = (args.no_cuda == False) and torch.cuda.is_available()
    net, _ = model.get()
    assert os.path.exists(args.checkpoint)
    net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)
    if args.use_cuda:
        net = net.cuda()
    # Initialize inference dataset (Unlabeled dataset)
    infer_dataset = Dataset(unlabeled_dataset, src_vocab)
    if args.batch_size is not None:
        infer_dataset.BATCH_SIZE = args.batch_size
    if args.max_batch_size is not None:
        infer_dataset.max_batch_size = args.max_batch_size
    if args.tokens_per_batch is not None:
        infer_dataset.tokens_per_batch = args.tokens_per_batch
    infer_dataiter = iter(
        infer_dataset.get_iterator(shuffle=True,
                                   group_by_size=True,
                                   include_indices=True))
    # Start ranking unlabeled dataset
    indices = np.arange(len(unlabeled_dataset))
    if active_func == "random":
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        random.shuffle(result)
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            # NOTE(review): idx is a dataset index but result was shuffled,
            # so result[idx] is not the entry for unlabeled_dataset[idx] —
            # the printed H:/V: likely pair the wrong hypothesis/score with
            # S:/T:.  The lc/margin branch below indexes positionally and
            # is consistent; verify intent here.
            print("H:", result[idx][2])
            print("T:", oracle[idx])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference, idx)
    elif active_func == "longest":
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        # Rebuild result as (token_count, dataset_index, hypothesis).
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        # Negated key => descending by length (longest first).
        result = sorted(result, key=lambda item: -item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            # NOTE(review): same dataset-index-vs-position mismatch as the
            # "random" branch above.
            print("H:", result[idx][2])
            print("T:", oracle[idx])
            print("V:", -result[idx][0])
            print("I:", args.input, args.reference, idx)
    elif active_func == "shortest":
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        result = [(len(
            remove_special_tok(remove_bpe(
                unlabeled_dataset[item[1]])).split(' ')), item[1], item[2])
                  for item in result]
        # Ascending by length (shortest first).
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        for idx in indices:
            print("S:", unlabeled_dataset[idx])
            # NOTE(review): same dataset-index-vs-position mismatch as above.
            print("H:", result[idx][2])
            print("T:", oracle[idx])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference, idx)
        # Re-sort by the precomputed lengths (ties may order differently
        # than the sort on result above).
        indices = indices[np.argsort(lengths[indices])]
    elif active_func in ["lc", "margin", "te", "tte"]:
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        # Ascending score order; sign convention comes from get_scores().
        result = sorted(result, key=lambda item: item[0])
        indices = [item[1] for item in result]
        indices = np.array(indices).astype('int')
        # Positional indexing keeps S:/H:/T:/V: consistently paired here.
        for idx in range(len(result)):
            print("S:", unlabeled_dataset[result[idx][1]])
            print("H:", result[idx][2])
            print("T:", oracle[result[idx][1]])
            print("V:", result[idx][0])
            print("I:", args.input, args.reference, result[idx][1])
def query_instances(args, unlabeled_dataset, active_func="random", tok_budget=None):
    """Select a token-budgeted batch of unlabeled sentences to annotate.

    Ranks ``unlabeled_dataset`` by the chosen strategy, then keeps the
    longest prefix of the ranking whose cumulative token count (BPE
    removed) stays within ``tok_budget``.

    Args:
        args: argparse-style namespace.  For the model-based strategies
            this reads args.no_cuda, args.checkpoint, args.batch_size,
            args.max_batch_size and args.tokens_per_batch, and writes
            args.use_cuda.
        unlabeled_dataset: sequence of (BPE-encoded) source sentences.
        active_func: one of "random", "longest", "shortest", "lc"
            (least confident), "margin", "te" (token entropy),
            "tte" (total token entropy).
        tok_budget: int token budget; capped at the pool's total size.

    Returns:
        (selected_sentences, selected_indices): the chosen sentences and
        their indices (np.ndarray) into ``unlabeled_dataset``.
    """
    assert active_func in [
        "random", "longest", "shortest", "lc", "margin", "te", "tte"
    ]
    # NOTE: assert is stripped under `python -O`; kept (rather than an
    # explicit raise) so the exception type seen by callers is unchanged.
    assert isinstance(tok_budget, int)

    # Token counts per sentence; BPE must be removed so counts reflect
    # real words rather than subword pieces.
    lengths = np.array([
        len(remove_special_tok(remove_bpe(s)).split())
        for s in unlabeled_dataset
    ])
    # Cap the budget at the total pool size so the cumsum mask below
    # can simply select everything.
    total_num = int(lengths.sum())
    if total_num < tok_budget:
        tok_budget = total_num

    # The model (and its scoring iterator) is only needed for the
    # uncertainty-based strategies.
    if active_func in ["lc", "margin", "te", "tte"]:
        # Reload network parameters from the checkpoint.
        args.use_cuda = not args.no_cuda and torch.cuda.is_available()
        net, _ = model.get()
        assert os.path.exists(args.checkpoint)
        net, src_vocab, tgt_vocab = load_model(args.checkpoint, net)
        if args.use_cuda:
            net = net.cuda()

        # Inference dataset over the unlabeled pool.
        infer_dataset = Dataset(unlabeled_dataset, src_vocab)
        if args.batch_size is not None:
            infer_dataset.BATCH_SIZE = args.batch_size
        if args.max_batch_size is not None:
            infer_dataset.max_batch_size = args.max_batch_size
        if args.tokens_per_batch is not None:
            infer_dataset.tokens_per_batch = args.tokens_per_batch
        infer_dataiter = iter(
            infer_dataset.get_iterator(shuffle=True,
                                       group_by_size=True,
                                       include_indices=True))

    # Rank the pool according to the chosen strategy.
    indices = np.arange(len(unlabeled_dataset))
    if active_func == "random":
        np.random.shuffle(indices)
    elif active_func == "longest":
        indices = indices[np.argsort(-lengths[indices])]
    elif active_func == "shortest":
        indices = indices[np.argsort(lengths[indices])]
    elif active_func in ["lc", "margin", "te", "tte"]:
        # get_scores yields (score, dataset_index, ...); ascending score
        # puts the most informative sentences first (sign convention is
        # get_scores()'s).
        result = get_scores(args, net, active_func, infer_dataiter,
                            src_vocab, tgt_vocab)
        result = sorted(result, key=lambda item: item[0])
        indices = np.array([item[1] for item in result], dtype=int)

    # Keep the ranked prefix that fits in the budget: cumsum is
    # nondecreasing, so the mask is a contiguous prefix.
    include = np.cumsum(lengths[indices]) <= tok_budget
    include = indices[include]
    return [unlabeled_dataset[idx] for idx in include], include