def generate_training_samples():
    """Build the retriever's training-sample dict and dump it to JSON.

    For every training question (phrase pair carrying a 'qid') it records:
      - 'query_index': the question's index in the document-ID dictionary,
      - 'document_range': the (start, end) index range of candidate documents
        for the question's weak type (falls back to the full range when the
        type is unknown),
      - 'positive_document_list': indices of the question's positive documents.

    Writes ../data/auto_QA_data/retriever_training_samples.json.
    Returns None; output is the side-effect file.
    """
    training_sample_dict = {}
    # Map docID (qid) -> integer index.
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    positive_q_docs_pair = data.load_json(POSITIVE_Q_DOCS)
    qtype_docs_range = data.load_json(QTYPE_DOC_RANGE)
    phrase_pairs, _ = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH, DIC_PATH,
                                          MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." %
          (len(phrase_pairs), TRAIN_QUESTION_ANSWER_PATH))
    for question in phrase_pairs:
        # Only well-formed (tokens, info-dict) pairs with a qid are usable.
        if len(question) == 2 and 'qid' in question[1]:
            key_weak, _, query_qid = AnalyzeQuestion(question[1])
            query_index = docID_dict[query_qid]
            if key_weak in qtype_docs_range:
                document_range = (qtype_docs_range[key_weak]['start'],
                                  qtype_docs_range[key_weak]['end'])
            else:
                # Unknown question type: search the whole document index.
                document_range = (0, len(docID_dict))
            positive_document_list = [
                docID_dict[doc] for doc in positive_q_docs_pair[query_qid]
            ]
            training_sample_dict[query_qid] = {
                'query_index': query_index,
                'document_range': document_range,
                'positive_document_list': positive_document_list
            }
    # 'with' guarantees the handle is closed even if serialization fails;
    # write() (not writelines()) is correct for a single JSON string.
    with open('../data/auto_QA_data/retriever_training_samples.json', 'w',
              encoding="UTF-8") as fw:
        fw.write(json.dumps(training_sample_dict, indent=1,
                            ensure_ascii=False))
    print('Writing retriever_training_samples.json is done!')
def initialize_document_embedding(int_flag=True, w2v=300, file_path=''):
    """Initialize the retriever's document-embedding table from a trained
    seq2seq (PhraseModel) checkpoint and save the resulting retriever state.

    Args:
        int_flag: if True, load the integer-token dictionary (DIC_PATH_INT);
            otherwise the plain one (DIC_PATH).
        w2v: embedding size of the retriever model.
        file_path: path to the trained PhraseModel checkpoint used to embed
            each document.

    Side effect: writes SAVES_DIR/initial_epoch_000_1.000.dat.
    """
    device = 'cuda'
    # Dict: word token -> ID.
    if not int_flag:
        emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
    else:
        emb_dict = data.load_dict(DIC_PATH=DIC_PATH_INT)
    ordered_docID_doc_list = data.get_ordered_docID_document(
        ORDERED_QID_QUESTION_DICT)
    docID_dict, doc_list = data.get_docID_indices(ordered_docID_doc_list)
    net = retriever_module.RetrieverModel(emb_size=w2v,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=False,
                                          device=device).to(device)
    net.cuda()
    net.zero_grad()
    # Trained seq2seq model supplying the word embeddings for documents.
    net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                             dict_size=len(emb_dict),
                             hid_size=model.HIDDEN_STATE_SIZE,
                             LSTM_FLAG=True,
                             ATT_FLAG=False,
                             EMBED_FLAG=False).to(device)
    net1.load_state_dict(torch.load(file_path))
    doc_embedding_list = get_document_embedding(doc_list, emb_dict, net1)
    # Add padding vector.
    doc_embedding_list.append([0.0] * model.EMBEDDING_DIM)
    doc_embedding_tensor = torch.tensor(doc_embedding_list).cuda()
    net.document_emb.weight.data = doc_embedding_tensor.clone().detach()
    MAP_for_queries = 1.0
    epoch = 0
    # exist_ok avoids the racy exists()/makedirs() pair the sibling
    # initializer already dropped.
    os.makedirs(SAVES_DIR, exist_ok=True)
    torch.save(
        net.state_dict(),
        os.path.join(SAVES_DIR,
                     "initial_epoch_%03d_%.3f.dat" % (epoch, MAP_for_queries)))
def initialize_document_embedding(
        file_path='../data/saves/maml_batch8_att=0_newdata2k_1storder_1task/'
                  'epoch_002_0.394_0.796.dat'):
    """Initialize the retriever's document embeddings (emb_size=50 variant)
    from a trained seq2seq checkpoint and save the retriever state.

    NOTE(review): this redefines the name of the parameterized variant above
    in the same module — whichever definition appears last wins; consider
    renaming one of them.

    Args:
        file_path: PhraseModel checkpoint used to embed documents. The
            default preserves the previously hard-coded path.

    Side effect: writes SAVES_DIR/epoch_000_1.000.dat.
    """
    device = 'cuda'
    # Dict: word token -> ID.
    emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
    ordered_docID_doc_list = data.get_ordered_docID_document(
        ORDERED_QID_QUESTION_DICT)
    docID_dict, doc_list = data.get_docID_indices(ordered_docID_doc_list)
    net = retriever_module.RetrieverModel(emb_size=50,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=False,
                                          device=device).to(device)
    net.cuda()
    net.zero_grad()
    # Trained seq2seq model supplying the word embeddings for documents.
    net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                             dict_size=len(emb_dict),
                             hid_size=model.HIDDEN_STATE_SIZE,
                             LSTM_FLAG=True,
                             ATT_FLAG=False,
                             EMBED_FLAG=False).to(device)
    net1.load_state_dict(torch.load(file_path))
    doc_embedding_list = get_document_embedding(doc_list, emb_dict, net1)
    # Add padding vector.
    doc_embedding_list.append([0.0] * model.EMBEDDING_DIM)
    doc_embedding_tensor = torch.tensor(doc_embedding_list).cuda()
    net.document_emb.weight.data = doc_embedding_tensor.clone().detach()
    MAP_for_queries = 1.0
    epoch = 0
    os.makedirs(SAVES_DIR, exist_ok=True)
    torch.save(
        net.state_dict(),
        os.path.join(SAVES_DIR,
                     "epoch_%03d_%.3f.dat" % (epoch, MAP_for_queries)))
# NOTE(review): this span is the interior of a larger function/script whose
# definition lies outside this chunk; only comments are added here.
net.cuda()
log.info("Model: %s", net)
# Load the pre-trained seq2seq model.
net.load_state_dict(torch.load(args.load))
# print("Pre-trained network params")
# for name, param in net.named_parameters():
#     print(name, param.shape)
log.info("Model loaded from %s, continue training in MAML-Reptile mode...",
         args.load)
# Reward scheme is selected by CLI flag; both branches only log the choice.
if (args.adaptive):
    log.info("Using adaptive reward to train the REINFORCE model...")
else:
    log.info("Using 0-1 sparse reward to train the REINFORCE model...")
# Map docID (qid) -> integer index into the document embedding table.
docID_dict, _ = data.get_docID_indices(
    data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
# Index -> qid.
rev_docID_dict = {id: doc for doc, id in docID_dict.items()}
qtype_docs_range = data.load_json(QTYPE_DOC_RANGE)
# Retriever with 50-dim document embeddings; whether those embeddings keep
# receiving gradients is controlled by --docembed_grad.
retriever_net = retriever_module.RetrieverModel(
    emb_size=50,
    dict_size=len(docID_dict),
    EMBED_FLAG=args.docembed_grad,
    device=device).to(device)
retriever_net.cuda()
log.info("Retriever model: %s", retriever_net)
# Warm-start the retriever from a previously saved checkpoint.
retriever_net.load_state_dict(torch.load(args.retrieverload))
log.info("Retriever model loaded from %s, continue training in RL mode...",
         args.retrieverload)
def establish_positive_question_documents_pair(MAX_TOKENS):
    """For each training question, collect the qids of its meta-learning
    support-set samples (plus the question's own qid) as its positive
    documents, and dump the mapping to JSON.

    Args:
        MAX_TOKENS: maximum token count passed through to data loading.

    Side effect: writes
    ../data/auto_QA_data/retriever_question_documents_pair.json.
    """
    # Dict: docID (qid) -> index.
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    # List of (question, {question information and answer}) pairs; the
    # training pairs are in 1:1 format.
    phrase_pairs, emb_dict = data.load_data_MAML(TRAIN_QUESTION_ANSWER_PATH,
                                                 DIC_PATH, MAX_TOKENS)
    print("Obtained %d phrase pairs with %d uniq words from %s." %
          (len(phrase_pairs), len(emb_dict), TRAIN_QUESTION_ANSWER_PATH))
    phrase_pairs_944K = data.load_data_MAML(TRAIN_944K_QUESTION_ANSWER_PATH,
                                            max_tokens=MAX_TOKENS)
    print("Obtained %d phrase pairs from %s." %
          (len(phrase_pairs_944K), TRAIN_944K_QUESTION_ANSWER_PATH))
    # Transform tokens into dictionary indices, then group for RL training.
    train_data = data.encode_phrase_pairs_RLTR(phrase_pairs, emb_dict)
    train_data = data.group_train_data_RLTR(train_data)
    train_data_944K = data.encode_phrase_pairs_RLTR(phrase_pairs_944K,
                                                    emb_dict)
    train_data_944K = data.group_train_data_RLTR_for_support(train_data_944K)
    dict944k = data.get944k(DICT_944K)
    print("Reading dict944k from %s is done. %d pairs in dict944k." %
          (DICT_944K, len(dict944k)))
    dict944k_weak = data.get944k(DICT_944K_WEAK)
    print("Reading dict944k_weak from %s is done. %d pairs in dict944k_weak" %
          (DICT_944K_WEAK, len(dict944k_weak)))
    metaLearner = metalearner.MetaLearner(
        samples=5,
        train_data_support_944K=train_data_944K,
        dict=dict944k,
        dict_weak=dict944k_weak,
        steps=5,
        weak_flag=True)
    question_documents_pair_list = {}
    idx = 0
    for temp_batch in data.iterate_batches(train_data, 1):
        task = temp_batch[0]
        if len(task) == 2 and 'qid' in task[1]:
            # Establish the support set for this task; its members' qids
            # become the task's positive documents.
            support_set = metaLearner.establish_support_set(
                task, metaLearner.steps, metaLearner.weak_flag,
                metaLearner.train_data_support_944K)
            documents = []
            if len(support_set) > 0:
                for support_sample in support_set:
                    if len(support_sample) == 2 and 'qid' in support_sample[1]:
                        documents.append(support_sample[1]['qid'])
            else:
                print('task %s has no support set!' % (str(task[1]['qid'])))
            # The question itself always counts as one of its positives.
            documents.append(task[1]['qid'])
            question_documents_pair_list[task[1]['qid']] = documents
            if idx % 100 == 0:
                print(idx)
            idx += 1
        else:
            print('task has no qid or len(task)!=2:')
            print(task)
    # 'with' guarantees the handle is closed; write() (not writelines())
    # is correct for a single JSON string.
    with open('../data/auto_QA_data/retriever_question_documents_pair.json',
              'w', encoding="UTF-8") as fw:
        fw.write(json.dumps(question_documents_pair_list, indent=1,
                            ensure_ascii=False))
    print('Writing retriever_question_documents_pair.json is done!')
def retriever_training(epochs, RETRIEVER_EMBED_FLAG=True, query_embedding=True):
    '''Train the retriever with AdaBound, maximizing the log-probability of
    each query's positive documents, and checkpoint whenever the sampled
    mean rank of positives improves.

    Args:
        epochs: number of passes over the training samples.
        RETRIEVER_EMBED_FLAG: whether document embeddings receive gradients.
        query_embedding: if True, queries are embedded via a trained seq2seq
            model; otherwise the retriever's own packed query index is used.

    One instance of the retriever training samples:
    query_index = [800000, 0, 2, 100000, 400000, 600000]
    document_range = [(700000, 944000), (1, 10), (10, 300000), (10, 300000),
                      (300000, 500000), (500000, 700000)]
    positive_document_list = [[700001-700000, 700002-700000, 900000-700000,
    910000-700000, 944000-2-700000], [2, 3], [13009-10, 34555-10, 234-10,
    6789-10, 300000-1-10], [11-10, 16-10, 111111-10, 222222-10, 222223-10],
    [320000-300000, 330000-300000, 340000-300000, 350000-300000,
    360000-300000], [600007-500000, 610007-500000, 620007-500000,
    630007-500000, 690007-500000]]'''
    retriever_path = '../data/saves/retriever/initial_epoch_000_1.000.dat'
    device = 'cuda'
    docID_dict, _ = data.get_docID_indices(
        data.get_ordered_docID_document(ORDERED_QID_QUESTION_DICT))
    training_samples = data.load_json(TRAINING_SAMPLE_DICT)
    net = retriever_module.RetrieverModel(emb_size=50,
                                          dict_size=len(docID_dict),
                                          EMBED_FLAG=RETRIEVER_EMBED_FLAG,
                                          device=device).to(device)
    net.load_state_dict(torch.load(retriever_path))
    net.zero_grad()
    # AdaBound: Adam-like adaptivity early, SGD-like final LR (final_lr).
    retriever_optimizer = adabound.AdaBound(
        filter(lambda p: p.requires_grad, net.parameters()),
        lr=1e-3,
        final_lr=0.1)
    emb_dict = None
    net1 = None
    qid_question_pair = {}
    if query_embedding:
        emb_dict = data.load_dict(DIC_PATH=DIC_PATH)
        # Trained seq2seq model supplying word embeddings for queries.
        net1 = model.PhraseModel(emb_size=model.EMBEDDING_DIM,
                                 dict_size=len(emb_dict),
                                 hid_size=model.HIDDEN_STATE_SIZE,
                                 LSTM_FLAG=True,
                                 ATT_FLAG=False,
                                 EMBED_FLAG=False).to(device)
        net1.load_state_dict(torch.load(RETRIEVER_PATH))
        qid_question_pair = data.get_qid_question_pairs(
            ORDERED_QID_QUESTION_DICT)
    max_value = MAX_MAP
    MAP_for_queries = MAX_MAP
    # Seed once, outside the loops: reseeding from the wall clock inside the
    # evaluation loop can repeat values and defeats reproducibility.
    random.seed(datetime.now())
    for i in range(epochs):
        print('Epoch %d is training......' % (i))
        for key, value in training_samples.items():
            retriever_optimizer.zero_grad()
            net.zero_grad()
            if query_embedding:
                if key in qid_question_pair:
                    question_tokens = qid_question_pair[key]
                else:
                    print("ERROR! NO SUCH QUESTION: %s!" % (str(key)))
                    continue
                query_tensor = data.get_question_embedding(
                    question_tokens, emb_dict, net1)
            else:
                query_tensor = torch.tensor(
                    net.pack_input(value['query_index']).tolist(),
                    requires_grad=False).cuda()
            document_range = (value['document_range'][0],
                              value['document_range'][1])
            logsoftmax_output = net(query_tensor, document_range)[0]
            logsoftmax_output = logsoftmax_output.cuda()
            # Positives are stored as absolute indices; shift them so they
            # index into the sliced (document_range) output.
            positive_document_list = [
                k - value['document_range'][0]
                for k in value['positive_document_list']
            ]
            positive_logsoftmax_output = torch.stack(
                [logsoftmax_output[k] for k in positive_document_list])
            # Maximize log-likelihood of the positive documents.
            loss_policy_v = -positive_logsoftmax_output.mean()
            loss_policy_v = loss_policy_v.cuda()
            loss_policy_v.backward()
            retriever_optimizer.step()
        # Evaluate after every epoch on a random ~1/40 subsample.
        # NOTE(review): "MAP" here is the mean rank of positives (lower is
        # better), not mean average precision.
        MAP_list = []
        for _ in range(len(training_samples) // 40):
            key, value = random.choice(list(training_samples.items()))
            if query_embedding:
                question_tokens = qid_question_pair[key]
                query_tensor = data.get_question_embedding(
                    question_tokens, emb_dict, net1)
            else:
                query_tensor = torch.tensor(
                    net.pack_input(value['query_index']).tolist(),
                    requires_grad=False).cuda()
            document_range = (value['document_range'][0],
                              value['document_range'][1])
            logsoftmax_output = net(query_tensor, document_range)[0]
            order = net.calculate_rank(logsoftmax_output.tolist())
            positive_document_list = [
                k - value['document_range'][0]
                for k in value['positive_document_list']
            ]
            orders = [order[k] for k in positive_document_list]
            MAP_list.append(mean(orders))
        MAP_for_queries = mean(MAP_list)
        print('------------------------------------------------------')
        print('Epoch %d, MAP_for_queries: %f' % (i, MAP_for_queries))
        print('------------------------------------------------------')
        # Checkpoint only on improvement (lower mean rank).
        if MAP_for_queries < max_value:
            max_value = MAP_for_queries
            if MAP_for_queries < 500:
                output_str = "AdaBound"
                if RETRIEVER_EMBED_FLAG:
                    output_str += "_DocEmbed"
                if query_embedding:
                    output_str += "_QueryEmbed"
                torch.save(
                    net.state_dict(),
                    os.path.join(
                        SAVES_DIR, output_str +
                        "_epoch_%03d_%.3f.dat" % (i, MAP_for_queries)))
                print('Save the state_dict: %s' %
                      (str(i) + ' ' + str(MAP_for_queries)))
        if MAP_for_queries < 10:
            break