def load_data(args):
    train_ann_path = os.path.join(args.data_dir, "train-sec-ann")
    dev_ann_path = os.path.join(args.data_dir, "dev-sec-ann")
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    passage_text = {}
    train_pos_id = []
    train_answers = []
    test_answers = []
    test_pos_id = []
    test_answers_trivia = []

    logger.info("Loading train ann")
    with open(train_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            train_pos_id.append(row[1])

    logger.info("Loading dev ann")
    with open(dev_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            test_pos_id.append(row[1])

    logger.info(
        "Finished loading data, pos_id length %d, train answers length %d, test answers length %d",
        len(train_pos_id), len(train_answers), len(test_answers))

    return (train_pos_id, test_pos_id)
def load_data(args):
    passage_path = os.path.join(args.passage_path, "psgs_w100.tsv")
    test_qa_path = os.path.join(args.test_qa_path, "nq-test.csv")
    trivia_test_qa_path = os.path.join(args.trivia_test_qa_path, "trivia-test.csv")
    train_ann_path = os.path.join(args.data_dir, "train-ann")
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    passage_text = {}
    train_pos_id = []
    train_answers = []
    test_answers = []
    test_questions = []
    test_answers_trivia = []
    test_questions_trivia = []

    logger.info("Loading train ann")
    with open(train_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            train_pos_id.append(int(row[1]))
            train_answers.append(eval(row[2]))

    logger.info("Loading test answers")
    with open(test_qa_path, "r", encoding="utf-8") as ifile:
        # file format: question, answers
        reader = csv.reader(ifile, delimiter='\t')
        for row in reader:
            test_answers.append(eval(row[1]))
            test_questions.append(str(row[0]))

    logger.info("Loading trivia test answers")
    with open(trivia_test_qa_path, "r", encoding="utf-8") as ifile:
        # file format: question, answers
        reader = csv.reader(ifile, delimiter='\t')
        for row in reader:
            test_answers_trivia.append(eval(row[1]))
            test_questions_trivia.append(str(row[0]))

    logger.info("Loading passages")
    with open(passage_path, "r", encoding="utf-8") as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title
        for row in reader:
            if row[0] != 'id':
                passage_text[pid2offset[int(row[0])]] = (row[1], row[2])
            if args.do_debug and len(passage_text) > 10:
                break

    logger.info(
        "Finished loading data, pos_id length %d, train answers length %d, test answers length %d",
        len(train_pos_id), len(train_answers), len(test_answers))

    return (passage_text, train_pos_id, train_answers, test_answers,
            test_answers_trivia, test_questions, test_questions_trivia)
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')

    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="train-query", emb_prefix="query_", is_query_inference=True)
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="test-query", emb_prefix="dev_query_", is_query_inference=True)
    dev_query_embedding_trivia, dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-test-query", emb_prefix="trivia_dev_query_", is_query_inference=True)
    real_dev_query_embedding, real_dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="dev-qas-query", emb_prefix="real-dev_query_", is_query_inference=True)
    real_dev_query_embedding_trivia, real_dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-dev-qas-query", emb_prefix="trivia_real-dev_query_", is_query_inference=True)
    # passage_embedding == None if args.split_ann_search == True
    passage_embedding, passage_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="passages", emb_prefix="passage_", is_query_inference=False,
        load_emb=not args.split_ann_search)

    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if args.local_rank != -1:
        dist.barrier()

    # if the id mapping is missing, reload only the ids (not the embeddings)
    if passage_embedding2id is None and is_first_worker():
        _, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=None, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_", is_query_inference=False,
            load_emb=False)
        logger.info(f"document id size: {passage_embedding2id.shape}")

    if is_first_worker():
        # passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia, dev_answers, dev_answers_trivia = preloaded_data

        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            print('passage embedding shape: ' + str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(16)
            cpu_index = faiss.IndexFlatIP(dim)
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            logger.info("***** Done ANN Index *****")
            _, dev_I = index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
            _, dev_I_trivia = index.search(dev_query_embedding_trivia, 100)  # I: [number of queries, topk]
            logger.info("Start searching for query embedding with length %d", len(query_embedding))
            _, I = index.search(query_embedding, top_k)  # I: [number of queries, topk]
        else:
            _, dev_I_trivia, real_dev_D, real_dev_I = document_split_faiss_index(
                logger=logger, args=args, top_k_dev=100, top_k=args.topk_training,
                checkpoint_step=checkpoint_step,
                dev_query_emb=dev_query_embedding_trivia,
                train_query_emb=real_dev_query_embedding,
                emb_prefix="passage_", two_query_set=True)
            dev_D, dev_I, _, I = document_split_faiss_index(
                logger=logger, args=args, top_k_dev=100, top_k=args.topk_training,
                checkpoint_step=checkpoint_step,
                dev_query_emb=dev_query_embedding,
                train_query_emb=query_embedding,
                emb_prefix="passage_", two_query_set=True)
            save_trec_file(
                dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
                trec_save_path=os.path.join(args.training_dir, "ann_data",
                                            "nq-test_" + checkpoint_step + ".trec"),
                topN=100)
            save_trec_file(
                real_dev_query_embedding2id, passage_embedding2id, real_dev_I, real_dev_D,
                trec_save_path=os.path.join(args.training_dir, "ann_data",
                                            "nq-dev_" + checkpoint_step + ".trec"),
                topN=100)

        # measure ANN mrr
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)
        real_dev_top_k_hits = validate(passage_text, dev_answers, real_dev_I, real_dev_query_embedding2id, passage_embedding2id)
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I_trivia, dev_query_embedding2id_trivia, passage_embedding2id)

        query_range_number = I.shape[0]
        json_dump_dict = {
            'top20': top_k_hits[19],
            'top100': top_k_hits[99],
            'top20_trivia': top_k_hits_trivia[19],
            'dev_top20': real_dev_top_k_hits[19],
            'dev_top100': real_dev_top_k_hits[99],
            'top100_trivia': top_k_hits_trivia[99],
            'checkpoint': checkpoint_path,
            'n_train_query': query_range_number,
        }
        logger.info(json_dump_dict)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(
            args, passage_text, train_answers, query_embedding2id,
            passage_embedding2id, I, train_pos_id)
        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        if args.dual_training:
            assert args.split_ann_search and args.gpu_index  # hard set
            logger.info("***** Begin ANN Index for dual d2q task *****")
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            cpu_index = faiss.IndexFlatIP(query_embedding.shape[1])
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(query_embedding)
            logger.info("***** Done building ANN Index for dual d2q task *****")

            # train_pos_id: a list, idx -> int pid
            train_pos_id_inversed = {}
            for qidx in range(query_embedding2id.shape[0]):
                qid = query_embedding2id[qidx]
                pid = int(train_pos_id[qid])
                if pid not in train_pos_id_inversed:
                    train_pos_id_inversed[pid] = [qid]
                else:
                    train_pos_id_inversed[pid].append(qid)

            possitive_training_passage_id = [train_pos_id[t] for t in query_embedding2id]
            # compatible with MaxP
            possitive_training_passage_id_embidx = []
            possitive_training_passage_id_to_subset_embidx = {}  # pid to indexes in pos_pas_embs
            possitive_training_passage_id_emb_counts = 0
            for pos_pid in possitive_training_passage_id:
                embidx = np.asarray(np.where(passage_embedding2id == pos_pid)).flatten()
                possitive_training_passage_id_embidx.append(embidx)
                possitive_training_passage_id_to_subset_embidx[int(pos_pid)] = np.asarray(
                    range(possitive_training_passage_id_emb_counts,
                          possitive_training_passage_id_emb_counts + embidx.shape[0]))
                possitive_training_passage_id_emb_counts += embidx.shape[0]
            possitive_training_passage_id_embidx = np.concatenate(possitive_training_passage_id_embidx, axis=0)

            if not args.split_ann_search:
                D, I = index.search(passage_embedding[possitive_training_passage_id_embidx], args.topk_training_d2q)
            else:
                positive_p_embs = loading_possitive_document_embedding(
                    logger, args.output_dir, checkpoint_step,
                    possitive_training_passage_id_embidx, emb_prefix="passage_")
                assert positive_p_embs.shape[0] == len(possitive_training_passage_id)
                D, I = index.search(positive_p_embs, args.topk_training_d2q)
                positive_p_embs = None
                del positive_p_embs

            index.reset()
            logger.info("***** Finish ANN searching for dual d2q task, construct *****")
            passage_negative_queries = GenerateNegativeQueryID(
                args, passage_text, train_answers, query_embedding2id,
                passage_embedding2id[possitive_training_passage_id_embidx],
                closest_ans=I,
                training_query_positive_id_inversed=train_pos_id_inversed)
            logger.info("***** Done ANN searching for negative queries *****")

        logger.info("***** Construct ANN Triplet *****")
        prefix = "ann_grouped_training_data_" if args.grouping_ann_data > 0 else "ann_training_data_"
        train_data_output_path = os.path.join(args.output_dir, prefix + str(output_num))
        query_range = list(range(query_range_number))
        random.shuffle(query_range)

        if args.grouping_ann_data > 0:
            with open(train_data_output_path, 'w') as f:
                counting = 0
                pos_q_group = {}
                pos_d_group = {}
                neg_D_group = {}  # {0: [], 1: [], 2: [], ...}
                if args.dual_training:
                    neg_Q_group = {}
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    pos_pid = train_pos_id[query_id]
                    pos_q_group[counting] = int(query_id)
                    pos_d_group[counting] = int(pos_pid)
                    neg_D_group[counting] = [int(neg_pid) for neg_pid in query_negative_passage[query_id]]
                    if args.dual_training:
                        neg_Q_group[counting] = [int(neg_qid) for neg_qid in passage_negative_queries[pos_pid]]
                    counting += 1
                    if counting >= args.grouping_ann_data:
                        jsonline_dict = {}
                        jsonline_dict["pos_q_group"] = pos_q_group
                        jsonline_dict["pos_d_group"] = pos_d_group
                        jsonline_dict["neg_D_group"] = neg_D_group
                        if args.dual_training:
                            jsonline_dict["neg_Q_group"] = neg_Q_group
                        f.write(f"{json.dumps(jsonline_dict)}\n")
                        counting = 0
                        pos_q_group = {}
                        pos_d_group = {}
                        neg_D_group = {}
                        if args.dual_training:
                            neg_Q_group = {}
        else:
            # plain tab-separated output, one training query per line
            with open(train_data_output_path, 'w') as f:
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    # if not query_id in train_pos_id:
                    #     continue
                    pos_pid = train_pos_id[query_id]
                    if not args.dual_training:
                        f.write("{}\t{}\t{}\n".format(
                            query_id, pos_pid,
                            ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))
                    else:
                        # if pos_pid not in effective_p_id or pos_pid not in training_query_positive_id_inversed:
                        #     continue
                        f.write("{}\t{}\t{}\t{}\n".format(
                            query_id, pos_pid,
                            ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id]),
                            ','.join(str(neg_qid) for neg_qid in passage_negative_queries[pos_pid])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump(json_dump_dict, f)
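# A minimal, hypothetical sketch (not part of the original pipeline) showing how the
# "ann_grouped_training_data_*" file written above could be read back: each line is a
# JSON dict with "pos_q_group", "pos_d_group", "neg_D_group" and, when dual training is
# enabled, "neg_Q_group", all keyed by the position of the example within the group.
# The helper name `read_grouped_ann_training_data` is illustrative only.
def read_grouped_ann_training_data(path):
    import json
    groups = []
    with open(path, 'r') as f:
        for line in f:
            record = json.loads(line)
            group = []
            # json.dumps turns the integer group keys into strings, so sort them numerically
            for key in sorted(record["pos_q_group"], key=int):
                group.append({
                    "query_id": record["pos_q_group"][key],
                    "pos_pid": record["pos_d_group"][key],
                    "neg_pids": record["neg_D_group"][key],
                    # present only when the file was written with args.dual_training
                    "neg_qids": record.get("neg_Q_group", {}).get(key),
                })
            groups.append(group)
    return groups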
def validate(args, closest_docs, dev_scores, query_embedding2id, passage_embedding2id):
    logger.info('Matching answers in top docs...')
    scores = dict()
    count = 0
    total = 0
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    passage_path = os.path.join(args.passage_path, "hotpot_wiki.tsv")
    idx2title = dict()
    title2text = dict()
    with open(passage_path, "r", encoding="utf-8") as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title
        for row in reader:
            if row[0] != 'id':
                idx2title[int(row[0])] = row[2]
                title2text[row[2]] = row[1]

    with open(args.data_dir + '/hotpot_dev_fullwiki_v1.json', 'r') as fin:
        dataset = json.load(fin)
    type_dict = pickle.load(open(args.data_dir + '/dev_type_results.pkl', 'rb'))
    instances = list()

    dev_ann_path = os.path.join(args.data_dir, "dev-sec-ann")
    test_id = list()
    test_pre_et = list()
    with open(dev_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            test_id.append(row[0])
            test_pre_et.append(row[1])

    first_hop_ets = pickle.load(open(args.data_dir + '/dev_first_hop_pred.pkl', 'rb'))

    pred_dict = dict()
    for query_idx in range(closest_docs.shape[0]):
        query_id = query_embedding2id[query_idx]
        qid = test_id[query_id]
        pre_et = test_pre_et[query_id]
        if qid not in pred_dict:
            pred_dict[qid] = {'chain': list(), 'score': list()}
        pre_score = first_hop_ets[qid]['score'][first_hop_ets[qid]['pred'].index(pre_et)]
        all_pred = closest_docs[query_idx]
        scs = dev_scores[query_idx]
        for i in range(len(dev_scores[query_idx])):
            if int(passage_embedding2id[all_pred[i]]) in offset2pid:
                pred_dict[qid]['chain'].append(
                    pre_et + '#######' +
                    normalize(idx2title[offset2pid[int(passage_embedding2id[all_pred[i]])]]))
                pred_dict[qid]['score'].append(float(scs[i]) + pre_score)
    print(len(pred_dict))

    sec_hop_pred = dict()
    for data in dataset:
        qid = data['_id']
        all_pairs = list()
        supp_set = set()
        for supp in data['supporting_facts']:
            title = supp[0]
            supp_set.add(normalize(title))
            # total += 1
        supp_set = list(supp_set)
        if qid in pred_dict:
            doc_scores = pred_dict[qid]['score']
            idxs = sorted(range(len(doc_scores)), key=lambda k: doc_scores[k], reverse=True)
            for idx in idxs[:250]:
                all_pairs.append(pred_dict[qid]['chain'][idx])
        sec_hop_pred[qid] = all_pairs

    pickle.dump(sec_hop_pred, open(args.data_dir + '/dev_sec_hop_pred_top250.pkl', 'wb'))
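# Hypothetical usage sketch (not in the original code): the chains dumped to
# "dev_sec_hop_pred_top250.pkl" above are strings of the form
# "<first_hop_title>#######<second_hop_title>"; this shows one way to split them back
# into title pairs. `load_sec_hop_pairs` is an illustrative name.
def load_sec_hop_pairs(path):
    import pickle
    with open(path, 'rb') as f:
        sec_hop_pred = pickle.load(f)
    pairs = {}
    for qid, chains in sec_hop_pred.items():
        # each chain joins the first-hop title and the predicted second-hop title
        pairs[qid] = [tuple(chain.split('#######', 1)) for chain in chains]
    return pairs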
def generate_new_ann(args):
    # print(test_pos_id.shape)
    # model = None
    model = load_model(args, args.model)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    latest_step_num = args.latest_num

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-eval-sec")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev-sec_" + str(latest_step_num) + "_", emb,
            is_query_inference=True, load_cache=args.load_cache)
    # exit()

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False, load_cache=args.load_cache)

    dim = passage_embedding.shape[1]
    # print(dev_query_embedding.shape)
    # print('passage embedding shape: ' + str(passage_embedding.shape))
    print('dev embedding shape: ' + str(dev_query_embedding.shape))
    faiss.omp_set_num_threads(16)
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(passage_embedding)
    logger.info('Data indexing completed.')

    # search the dev queries in chunks of 5000 to bound memory usage
    nums = int(dev_query_embedding.shape[0] / 5000) + 1
    II = list()
    sscores = list()
    for i in range(nums):
        score, idx = cpu_index.search(dev_query_embedding[i * 5000:(i + 1) * 5000], args.topk)  # I: [number of queries, topk]
        II.append(idx)
        sscores.append(score)
        logger.info("Split done %d", i)
    dev_I = II[0]
    scores = sscores[0]
    for i in range(1, nums):
        dev_I = np.concatenate((dev_I, II[i]), axis=0)
        scores = np.concatenate((scores, sscores[i]), axis=0)

    validate(args, dev_I, scores, dev_query_embedding2id, passage_embedding2id)
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "test-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
    dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
    with dev_query_cache_trivia as emb:
        dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False, load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr on the NQ test queries
        _, dev_I = cpu_index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)

        # measure ANN mrr on the TriviaQA test queries
        _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100)  # I: [number of queries, topk]
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I, dev_query_embedding2id_trivia, passage_embedding2id)

        logger.info("Start searching for query embedding with length %d", len(query_embedding))
        _, I = cpu_index.search(query_embedding, top_k)  # I: [number of queries, topk]

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(
            args, passage_text, train_answers, query_embedding2id,
            passage_embedding2id, I, train_pos_id)
        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid,
                    ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({
                'top20': top_k_hits[19],
                'top100': top_k_hits[99],
                'top20_trivia': top_k_hits_trivia[19],
                'top100_trivia': top_k_hits_trivia[99],
                'checkpoint': checkpoint_path
            }, f)
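# Minimal illustrative sketch (assumed, not from the original repo) of how the
# tab-separated "ann_training_data_*" file written above can be parsed: one line per
# training query with its positive passage id and a comma-joined list of negatives.
# `iter_ann_training_data` is a hypothetical helper name.
def iter_ann_training_data(path):
    with open(path, 'r') as f:
        for line in f:
            query_id, pos_pid, neg_pids = line.rstrip('\n').split('\t')
            yield int(query_id), int(pos_pid), [int(p) for p in neg_pids.split(',') if p]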
def validate(args, closest_docs, dev_scores, query_embedding2id, passage_embedding2id):
    logger.info('Matching answers in top docs...')
    scores = dict()
    count = 0
    total = 0
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    passage_path = os.path.join(args.passage_path, "hotpot_wiki.tsv")
    idx2title = dict()
    title2text = dict()
    with open(passage_path, "r", encoding="utf-8") as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title
        for row in reader:
            if row[0] != 'id':
                idx2title[int(row[0])] = row[2]
                title2text[row[2]] = row[1]

    with open(args.data_dir + '/hotpot_dev_fullwiki_v1.json', 'r') as fin:
        dataset = json.load(fin)
    type_dict = pickle.load(open(args.data_dir + '/dev_type_results.pkl', 'rb'))
    instances = list()
    first_hop_ets = dict()

    for query_idx in range(closest_docs.shape[0]):
        query_id = query_embedding2id[query_idx]
        all_scores = list()
        doc_ids = list()
        all_pred = closest_docs[query_idx]
        scs = dev_scores[query_idx]
        for i in range(len(dev_scores[query_idx])):
            if int(passage_embedding2id[all_pred[i]]) in offset2pid:
                doc_ids.append(offset2pid[int(passage_embedding2id[all_pred[i]])])
                all_scores.append(float(scs[i]))

        data = dataset[query_id]
        qid = data['_id']
        supp_set = set()
        for supp in data['supporting_facts']:
            title = supp[0]
            supp_set.add(normalize(title))
        total += len(supp_set)
        for ii, d_id in enumerate(doc_ids[:10]):
            title = normalize(idx2title[d_id])
            if title in supp_set:
                count += 1

        first_hop_ets[qid] = {
            'score': all_scores,
            'pred': [normalize(idx2title[idx]) for idx in doc_ids]
        }
        if type_dict[qid] == 'comparison':
            continue
        for et in [normalize(idx2title[idx]) for idx in doc_ids]:
            pre_evidence = ''.join(title2text[et])
            qq = data['question'] + ' ' + '[SEP]' + ' ' + et.replace('_', ' ') + ' ' + '[SEP]' + ' ' + pre_evidence
            instances.append({
                'dataset': 'hotpot_dev_sec',
                'question': qq,
                'qid': qid,
                'answers': list(),
                'first_hop_cts': [et]
            })

    with open(args.data_dir + '/dev_sec_hop_data.json', 'w', encoding='utf-8') as f:
        json.dump(instances, f, indent=2)
    pickle.dump(first_hop_ets, open(args.data_dir + '/dev_first_hop_pred.pkl', 'wb'))
    logger.info("first hop coverage %f", count / total)
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):
    # passage_text, train_pos_id, train_answers, test_answers, test_pos_id = preloaded_data
    # print(test_pos_id.shape)
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-sec-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_sec_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-sec-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_sec_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False, load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        train_pos_id, test_pos_id = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        # num_q = passage_embedding.shape[0]
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info('Data indexing completed.')

        logger.info("Start searching for query embedding with length %d", len(query_embedding))
        # search the train queries in 15 fixed chunks of 5000 to bound memory usage
        II = list()
        for i in range(15):
            _, idx = cpu_index.search(query_embedding[i * 5000:(i + 1) * 5000], top_k)  # I: [number of queries, topk]
            II.append(idx)
            logger.info("Split done %d", i)
        I = II[0]
        for i in range(1, 15):
            I = np.concatenate((I, II[i]), axis=0)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        # logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = dict()
        for query_idx in range(I.shape[0]):
            query_id = query_embedding2id[query_idx]
            doc_ids = [passage_embedding2id[pidx] for pidx in I[query_idx]]
            neg_docs = list()
            for doc_id in doc_ids:
                pos_id = [int(p_id) for p_id in train_pos_id[query_id].split(',')]
                if doc_id in pos_id:
                    continue
                if doc_id in neg_docs:
                    continue
                neg_docs.append(doc_id)
            query_negative_passage[query_id] = neg_docs
        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid,
                    ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        _, dev_I = cpu_index.search(dev_query_embedding, 10)  # I: [number of queries, topk]
        top_k_hits = validate(test_pos_id, dev_I, dev_query_embedding2id, passage_embedding2id)
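# Hedged sketch only: the `validate` called just above is not defined in this file.
# Assuming it reports hit rates indexed like the `top_k_hits[19]` / `top_k_hits[99]`
# lookups used elsewhere in this code, a top-k hit computation over the second-hop
# positives could look roughly like the hypothetical helper below (dev_I rows are
# indices into passage_embedding2id, and test_pos_id holds comma-joined positive pids
# per dev query, as loaded by the "dev-sec-ann" reader above).
def compute_top_k_hits(test_pos_id, dev_I, dev_query_embedding2id, passage_embedding2id, max_k=10):
    top_k_hits = [0] * max_k
    for query_idx in range(dev_I.shape[0]):
        query_id = dev_query_embedding2id[query_idx]
        positives = {int(p) for p in str(test_pos_id[query_id]).split(',')}
        ranked_pids = [int(passage_embedding2id[pidx]) for pidx in dev_I[query_idx]]
        # record the first rank at which any positive passage appears
        for rank, pid in enumerate(ranked_pids[:max_k]):
            if pid in positives:
                for k in range(rank, max_k):
                    top_k_hits[k] += 1
                break
    return [hits / dev_I.shape[0] for hits in top_k_hits]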