def inference_or_load_embedding(args, logger, model, checkpoint_path, text_data_prefix, emb_prefix,
                                is_query_inference=True, checkonly=False, load_emb=True):
    # logging.info(f"checkpoint_path {checkpoint_path}")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')
    emb_file_pattern = os.path.join(args.output_dir, f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_*.pb')
    emb_file_lists = glob.glob(emb_file_pattern)
    emb_file_lists = sorted(emb_file_lists,
                            key=lambda name: int(name.split('_')[-1].replace('.pb', '')))  # sort by shard (split) number
    logger.info(f"pattern {emb_file_pattern}\n file lists: {emb_file_lists}")
    embedding, embedding2id = None, None
    if len(emb_file_lists) > 0:
        if is_first_worker():
            logger.info(f"***** found existing embedding files {emb_file_pattern}, loading... *****")
            if checkonly:
                logger.info("check embedding files only, not loading")
                return embedding, embedding2id
            embedding = []
            embedding2id = []
            for emb_file in emb_file_lists:
                if load_emb:
                    with open(emb_file, 'rb') as handle:
                        embedding.append(pickle.load(handle))
                embid_file = emb_file.replace('emb_p', 'embid_p')
                with open(embid_file, 'rb') as handle:
                    embedding2id.append(pickle.load(handle))
            if (load_emb and not embedding) or (not embedding2id):
                logger.error("No data found for checkpoint: %s", emb_file_pattern)
            if load_emb:
                embedding = np.concatenate(embedding, axis=0)
            embedding2id = np.concatenate(embedding2id, axis=0)
            # return embedding, embedding2id
        # else:
        if args.local_rank != -1:
            dist.barrier()  # wait for the first worker when multi-processing
    else:
        logger.info(f"***** inference of {text_data_prefix} *****")
        query_collection_path = os.path.join(args.data_dir, text_data_prefix)
        query_cache = EmbeddingCache(query_collection_path)
        with query_cache as emb:
            embedding, embedding2id = StreamInferenceDoc(
                args, model,
                GetProcessingFn(args, query=is_query_inference),
                emb_prefix + str(checkpoint_step) + "_",
                emb,
                is_query_inference=is_query_inference)
    return embedding, embedding2id
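# Illustrative sketch (not part of the pipeline above): how the sharded embedding files
# consumed by inference_or_load_embedding could be produced and re-assembled. It assumes
# each "emb_p" shard and its matching "embid_p" shard are pickled numpy arrays, mirroring
# the f'{emb_prefix}{step}__emb_p__data_obj_{i}.pb' naming used above; the helper name
# and shard layout here are our own.
def _demo_write_and_load_embedding_shards(output_dir, emb_prefix, checkpoint_step, shards):
    """Write (embedding, embedding2id) shards, then load them back in shard order."""
    for i, (emb, emb2id) in enumerate(shards):
        emb_path = os.path.join(output_dir, f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_{i}.pb')
        with open(emb_path, 'wb') as handle:
            pickle.dump(emb, handle)
        with open(emb_path.replace('emb_p', 'embid_p'), 'wb') as handle:
            pickle.dump(emb2id, handle)
    # Load back in the same order inference_or_load_embedding uses (sorted by shard number).
    files = sorted(
        glob.glob(os.path.join(output_dir, f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_*.pb')),
        key=lambda name: int(name.split('_')[-1].replace('.pb', '')))
    embedding = np.concatenate([pickle.load(open(p, 'rb')) for p in files], axis=0)
    embedding2id = np.concatenate(
        [pickle.load(open(p.replace('emb_p', 'embid_p'), 'rb')) for p in files], axis=0)
    return embedding, embedding2id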
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        latest_step_num,
        training_query_positive_id=None,
        dev_query_positive_id=None,
):
    config, tokenizer, model = load_model(args, checkpoint_path)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    if args.inference:
        return

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False)
    logger.info("***** Done passage inference *****")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # dev_I: [number of dev queries, 100]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(
            args, dev_query_embedding2id, passage_embedding2id,
            dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (
            effective_idx == (chunk_factor - 1)) else (
            q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding2id,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, pos_pid, ','.join(
                            str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
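# Illustrative sketch (assumptions, not the training pipeline): the exact inner-product
# retrieval step that generate_new_ann performs with FAISS, run on random data.
# IndexFlatIP scores by dot product between query and passage vectors, matching the
# dual-encoder scoring used above; the shapes, seed, and top_k value are invented here.
def _demo_faiss_retrieval(num_passages=1000, num_queries=4, dim=64, top_k=10):
    rng = np.random.default_rng(0)
    # FAISS expects float32 matrices of shape [n, dim].
    passage_embedding = rng.standard_normal((num_passages, dim)).astype(np.float32)
    query_embedding = rng.standard_normal((num_queries, dim)).astype(np.float32)
    index = faiss.IndexFlatIP(dim)      # exact maximum-inner-product index
    index.add(passage_embedding)        # add all passage vectors
    scores, I = index.search(query_embedding, top_k)  # I: [num_queries, top_k] passage row ids
    return scores, I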
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        training_query_positive_id,
        dev_query_positive_id,
        latest_step_num):
    config, tokenizer, model = load_model(args, checkpoint_path)

    dataFound = False
    if args.end_output_num == 0 and is_first_worker():
        dataFound, query_embedding, query_embedding2id, dev_query_embedding, dev_query_embedding2id, \
            passage_embedding, passage_embedding2id = load_init_embeddings(args)
    if args.local_rank != -1:
        dist.barrier()

    if not dataFound:
        logger.info("***** inference of dev query *****")
        name_ = None
        if args.dataset == "dl_test_2019":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_1":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-1")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_2":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-2")
            name_ = "test_query_"
        else:
            raise Exception('Dataset should be one of {dl_test_2019, dl_test_2019_1, dl_test_2019_2}!')
        dev_query_cache = EmbeddingCache(dev_query_collection_path)
        with dev_query_cache as emb:
            dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
                args, model, GetProcessingFn(args, query=True),
                name_ + str(latest_step_num) + "_", emb,
                is_query_inference=True)

        logger.info("***** inference of passages *****")
        passage_collection_path = os.path.join(args.data_dir, "passages")
        passage_cache = EmbeddingCache(passage_collection_path)
        with passage_cache as emb:
            passage_embedding, passage_embedding2id = StreamInferenceDoc(
                args, model, GetProcessingFn(args, query=False),
                "passage_" + str(latest_step_num) + "_", emb,
                is_query_inference=False)
        logger.info("***** Done passage inference *****")

        if args.inference:
            return

        logger.info("***** inference of train query *****")
        train_query_collection_path = os.path.join(args.data_dir, "train-query")
        train_query_cache = EmbeddingCache(train_query_collection_path)
        with train_query_cache as emb:
            query_embedding, query_embedding2id = StreamInferenceDoc(
                args, model, GetProcessingFn(args, query=True),
                "query_" + str(latest_step_num) + "_", emb,
                is_query_inference=True)
    else:
        logger.info("***** Found pre-existing embeddings; not running inference again. *****")

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # dev_I: [number of dev queries, 100]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(
            args, dev_query_embedding2id, passage_embedding2id,
            dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) else (q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())

        _, dev_I_dist = cpu_index.search(dev_query_embedding, top_k)
        distrib, samplingDist = getSamplingDist(
            args, dev_I_dist, dev_query_embedding2id,
            dev_query_positive_id, passage_embedding2id)
        sampling_dist_data = {'distrib': distrib, 'samplingDist': samplingDist}
        dist_output_path = os.path.join(args.output_dir, "dist_" + str(output_num))
        with open(dist_output_path, 'wb') as f:
            pickle.dump(sampling_dist_data, f)

        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding,
            query_embedding2id,
            passage_embedding,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id,
            samplingDist,
            output_num)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))
        train_query_cache.open()
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                pos_score = get_BM25_score(query_id, pos_pid, train_query_cache, tokenizer)
                neg_scores = {}
                for neg_pid in query_negative_passage[query_id]:
                    neg_scores[neg_pid] = get_BM25_score(query_id, neg_pid, train_query_cache, tokenizer)
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id,
                        str(pos_pid) + ":" + str(round(pos_score, 3)),
                        ','.join(
                            str(neg_pid) + ":" + str(round(neg_scores[neg_pid], 3))
                            for neg_pid in query_negative_passage[query_id])))
        train_query_cache.close()

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev