def ann_data_gen(args):
    last_checkpoint = args.last_checkpoint_dir
    ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir)
    if is_first_worker():
        logger.info("Getting bm25_helper")
        global bm25_helper
        bm25_helper = BM25_helper(args)
        logger.info("Done loading bm25_helper")
    output_num = ann_no + 1
    logger.info("starting output number %d", output_num)

    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)

    training_positive_id, dev_positive_id = load_positive_ids(args)

    while args.end_output_num == -1 or output_num <= args.end_output_num:
        next_checkpoint, latest_step_num = get_latest_checkpoint(args)
        if args.only_keep_latest_embedding_file:
            latest_step_num = 0
        if next_checkpoint == last_checkpoint:
            print("Sleeping for 1 hr")
            time.sleep(3600)
        else:
            logger.info("start generate ann data number %d", output_num)
            logger.info("next checkpoint at " + next_checkpoint)
            generate_new_ann(
                args,
                output_num,
                next_checkpoint,
                training_positive_id,
                dev_positive_id,
                latest_step_num)
            if args.inference:
                break
            logger.info("finished generating ann data number %d", output_num)
            output_num += 1
            last_checkpoint = next_checkpoint
        if args.local_rank != -1:
            dist.barrier()
def ann_data_gen(args):
    last_checkpoint = args.last_checkpoint_dir
    ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir)
    output_num = ann_no + 1
    logger.info("starting output number %d", output_num)
    preloaded_data = None

    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)
        preloaded_data = load_data(args)

    while args.end_output_num == -1 or output_num <= args.end_output_num:
        next_checkpoint, latest_step_num = get_latest_checkpoint(args)
        logger.info(f"get next_checkpoint {next_checkpoint} latest_step_num {latest_step_num}")
        if args.only_keep_latest_embedding_file:
            latest_step_num = 0
        if next_checkpoint == last_checkpoint:
            time.sleep(60)
        else:
            logger.info("start generate ann data number %d", output_num)
            logger.info("next checkpoint at " + next_checkpoint)
            generate_new_ann(args, output_num, next_checkpoint,
                             preloaded_data, latest_step_num)
            logger.warning("process rank: %s, finished generating ann data number %d",
                           args.local_rank, output_num)
            output_num += 1
            last_checkpoint = next_checkpoint
        if args.local_rank != -1:
            dist.barrier()
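# Both ann_data_gen variants above poll the filesystem: get_latest_ann_data()
# finds the highest-numbered ANN round already produced, and
# get_latest_checkpoint() finds the newest trainer checkpoint. A minimal
# sketch of what such a helper could look like is below; the file-name
# conventions ("ann_ndcg_<n>", "ann_training_data_<n>") are taken from this
# file, but the helper body itself is an assumption, not the repository's
# actual implementation.
import glob
import json
import os


def get_latest_ann_data_sketch(ann_dir):
    """Return (ann_no, ann_path, ndcg_json) for the newest finished ANN round."""
    files = glob.glob(os.path.join(ann_dir, "ann_ndcg_*"))
    if not files:
        return -1, None, None
    ann_no = max(int(f.split("_")[-1]) for f in files)
    ann_path = os.path.join(ann_dir, "ann_training_data_" + str(ann_no))
    with open(os.path.join(ann_dir, "ann_ndcg_" + str(ann_no)), "r") as f:
        ndcg_json = json.load(f)  # written by generate_new_ann at round end
    return ann_no, ann_path, ndcg_json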
def save_checkpoint(args, model, tokenizer):
    # Saving best-practices: if you use default names for the model, you can
    # reload it using from_pretrained()
    if args.do_train and is_first_worker():
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`
        if 'fairseq' not in args.train_model_type:
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training
            model_to_save.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)
        else:
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, 'model.pt'))

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))
    dist.barrier()
def evaluation(args, model, tokenizer):
    # Evaluation
    results = {}
    if args.do_eval:
        model_dir = args.model_name_or_path if args.model_name_or_path else args.output_dir
        checkpoints = [model_dir]
        for checkpoint in checkpoints:
            global_step = checkpoint.split("-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split("/")[-1] if checkpoint.find("checkpoint") != -1 else ""
            model.eval()
            recall = passage_dist_eval_last(args, model, tokenizer)
            print('recall@1000: ', recall)
            reranking_mrr, full_ranking_mrr = passage_dist_eval(args, model, tokenizer)
            if is_first_worker():
                print("Reranking/Full ranking mrr: {0}/{1}".format(
                    str(reranking_mrr), str(full_ranking_mrr)))
            dist.barrier()
    return results
def ann_data_gen(args):
    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)

    training_positive_id, dev_positive_id = load_positive_ids(args)
    finished_checkpoint_list = []
    while True:
        all_checkpoint_lists = get_all_checkpoint(args)  # include init_checkpoint
        logger.info("get all the checkpoints list:\n %s", all_checkpoint_lists)
        for checkpoint_path in all_checkpoint_lists:
            if checkpoint_path not in finished_checkpoint_list:
                logger.info(f"inference and eval for checkpoint at {checkpoint_path}")
                generate_new_ann(args, checkpoint_path)
                logger.info(f"finished generating ann data number at {checkpoint_path}")
                finished_checkpoint_list.append(checkpoint_path)
                if args.local_rank != -1:
                    dist.barrier()
        if args.inference_one_specified_ckpt:
            break
        time.sleep(600)
def compute_mrr(D, I, qids, ref_dict):
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    if is_first_worker():
        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # We pad with negative pids and distance -128 - if they make it to the
        # top we have a problem
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :10]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                candidate_dict[qid] = [0] * 1000
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    j += 1
                    seen_pids.add(pid)
        allowed, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)
        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)
    return mrr
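# compute_mrr() merges the per-worker KNN shards along the candidate axis and
# re-sorts by score, so the global top-10 survives regardless of which worker
# found it. The toy example below (made-up data, not from the repository)
# shows why the argsort / take_along_axis pair works on concatenated shards.
import numpy as np

# Two workers each return top-2 (score, pid) for the same single query.
D_merged = np.array([[0.9, 0.2, 0.8, 0.4]])   # scores: worker0 ++ worker1
I_merged = np.array([[11, 12, 21, 22]])       # pids:   worker0 ++ worker1

idx = np.argsort(D_merged, axis=1)[:, ::-1]   # descending by score
print(np.take_along_axis(I_merged, idx, axis=1))  # [[11 21 22 12]] - global ranking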
def compute_mrr_last(D, I, qids, ref_dict, dev_query_positive_id):
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    final_recall = 0.0
    if is_first_worker():
        prediction = {}
        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # We pad with negative pids and distance -128 - if they make it to the
        # top we have a problem
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :1000]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                prediction[qid] = {}
                candidate_dict[qid] = [0] * 1000
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    prediction[qid][pid] = -(j + 1)  # negative rank as score
                    j += 1
                    seen_pids.add(pid)
        allowed, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)
        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)

        evaluator = pytrec_eval.RelevanceEvaluator(
            convert_to_string_id(dev_query_positive_id), {'recall'})
        eval_query_cnt = 0
        recall = 0
        topN = 1000
        result = evaluator.evaluate(convert_to_string_id(prediction))
        for k in result.keys():
            eval_query_cnt += 1
            recall += result[k]["recall_" + str(topN)]
        final_recall = recall / eval_query_cnt
        print('final_recall: ', final_recall)
    return mrr, final_recall
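# compute_mrr_last() scores recall@1000 with pytrec_eval by feeding it a qrels
# dict and a run dict whose "scores" are negative ranks (higher = better).
# A self-contained usage sketch with made-up ids:
import pytrec_eval

qrels = {"q1": {"d1": 1}}                     # relevant docs per query
run = {"q1": {"d1": -1.0, "d2": -2.0}}        # -rank as the score
evaluator = pytrec_eval.RelevanceEvaluator(qrels, {"recall"})
print(evaluator.evaluate(run)["q1"]["recall_5"])  # 1.0 - d1 retrieved in top 5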
def ann_data_gen(args):
    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)

    checkpoint_path = args.init_model_dir
    logger.info(f"inference and eval for checkpoint at {checkpoint_path}")
    generate_new_ann(args, checkpoint_path)
    logger.info(f"finished generating ann data number at {checkpoint_path}")
    if args.local_rank != -1:
        dist.barrier()
def inference_or_load_embedding(args, logger, model, checkpoint_path,
                                text_data_prefix, emb_prefix,
                                is_query_inference=True, checkonly=False,
                                load_emb=True):
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')
    emb_file_pattern = os.path.join(
        args.output_dir,
        f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_*.pb')
    emb_file_lists = glob.glob(emb_file_pattern)
    emb_file_lists = sorted(
        emb_file_lists,
        key=lambda name: int(name.split('_')[-1].replace('.pb', '')))  # sort by split num
    logger.info(f"pattern {emb_file_pattern}\n file lists: {emb_file_lists}")
    embedding, embedding2id = None, None
    if len(emb_file_lists) > 0:
        if is_first_worker():
            logger.info(f"***** found existing embedding files {emb_file_pattern}, loading... *****")
            if checkonly:
                logger.info("check embedding files only, not loading")
                return embedding, embedding2id
            embedding = []
            embedding2id = []
            for emb_file in emb_file_lists:
                if load_emb:
                    with open(emb_file, 'rb') as handle:
                        embedding.append(pickle.load(handle))
                embid_file = emb_file.replace('emb_p', 'embid_p')
                with open(embid_file, 'rb') as handle:
                    embedding2id.append(pickle.load(handle))
            if (load_emb and not embedding) or (not embedding2id):
                logger.error("No data found for checkpoint: %s", emb_file_pattern)
            if load_emb:
                embedding = np.concatenate(embedding, axis=0)
            embedding2id = np.concatenate(embedding2id, axis=0)
    else:
        logger.info(f"***** inference of {text_data_prefix} *****")
        query_collection_path = os.path.join(args.data_dir, text_data_prefix)
        query_cache = EmbeddingCache(query_collection_path)
        with query_cache as emb:
            embedding, embedding2id = StreamInferenceDoc(
                args, model,
                GetProcessingFn(args, query=is_query_inference),
                emb_prefix + str(checkpoint_step) + "_",
                emb,
                is_query_inference=is_query_inference)
    return embedding, embedding2id
def load_model(args):
    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    args.output_mode = "classification"
    label_list = ["0", "1"]
    num_labels = len(label_list)

    # store args
    if args.local_rank != -1:
        args.world_size = torch.distributed.get_world_size()
        args.rank = dist.get_rank()

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        # Make sure only the first process in distributed training will
        # download model & vocab
        torch.distributed.barrier()
    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]

    if is_first_worker():
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

    model = configObj.model_class(args)

    if args.local_rank == 0:
        # Release the other processes once the first one has downloaded
        # model & vocab
        torch.distributed.barrier()

    model.to(args.device)
    return model
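# load_model() uses the standard "rank 0 first" barrier dance: non-zero ranks
# wait before constructing the model (so only rank 0 downloads weights), and
# rank 0 waits after, releasing everyone once the cache is warm. The same
# pattern in isolation (assumes an initialized process group; names are
# illustrative, not part of this repository):
import torch.distributed as dist


def rank0_first(local_rank, build_fn):
    if local_rank not in [-1, 0]:
        dist.barrier()        # wait for rank 0 to download/build
    obj = build_fn()
    if local_rank == 0:
        dist.barrier()        # release the waiting ranks
    return obj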
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        training_query_positive_id,
        dev_query_positive_id,
        latest_step_num):
    if args.gpu_index:
        clean_faiss_gpu()
    if not args.not_load_model_for_inference:
        config, tokenizer, model = load_model(args, checkpoint_path)
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')

    def evaluation(dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
                   trec_prefix="real-dev_query_", test_set="trec2019",
                   split_idx=-1, d2q_eval=False, d2q_qrels=None):
        if d2q_eval:
            qrels = d2q_qrels
        else:
            if args.data_type == 0:
                if test_set == "marcodev":
                    qrels = "../data/raw_data/msmarco-docdev-qrels.tsv"
                elif test_set == "trec2019":
                    qrels = "../data/raw_data/2019qrels-docs.txt"
            elif args.data_type == 1:
                if test_set == "marcodev":
                    qrels = "../data/raw_data/qrels.dev.small.tsv"
            else:
                logging.error("wrong data type")
                exit()
        trec_path = os.path.join(args.output_dir,
                                 trec_prefix + str(checkpoint_step) + ".trec")
        save_trec_file(
            dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
            trec_save_path=trec_path, topN=200)
        convert_trec_to_MARCO_id(
            data_type=args.data_type, test_set=test_set,
            processed_data_dir=args.data_dir,
            trec_path=trec_path, d2q_reversed_trec_file=d2q_eval)

        trec_path = trec_path.replace(".trec", ".formatted.trec")
        met = Metric()
        if split_idx >= 0:
            split_file_path = qrels + f"{args.dev_split_num}_fold.split_dict"
            with open(split_file_path, 'rb') as f:
                split = pickle.load(f)
        else:
            split = None
        ndcg10 = met.get_metric(qrels, trec_path, 'ndcg_cut_10', split, split_idx)
        mrr10 = met.get_mrr(qrels, trec_path, 'mrr_cut_10', split, split_idx)
        mrr100 = met.get_mrr(qrels, trec_path, 'mrr_cut_100', split, split_idx)
        logging.info(
            f"evaluation for {test_set}, trec_file {trec_path}, split_idx {split_idx}: "
            f"ndcg_cut_10 : {ndcg10}, mrr_cut_10 : {mrr10}, mrr_cut_100 : {mrr100}")
        return ndcg10

    # Inference
    if args.data_type == 0:
        # TREC DL 2019 eval set; the "dev-query" file is actually the trec-dl test set
        trec2019_query_embedding, trec2019_query_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
            text_data_prefix="dev-query", emb_prefix="dev_query_",
            is_query_inference=True)
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="real-dev-query", emb_prefix="real-dev_query_",
        is_query_inference=True)
    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="train-query", emb_prefix="query_",
        is_query_inference=True)
    if not args.split_ann_search:
        # merge all passages
        passage_embedding, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_",
            is_query_inference=False)
    else:
        # keep ids only
        _, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_",
            is_query_inference=False, load_emb=False)
    # FirstP shape:
    #   passage_embedding:    [[vec_0], [vec_1], [vec_2], [vec_3] ...]
    #   passage_embedding2id: [id0, id1, id2, id3, ...]
    # MaxP shape:
    #   passage_embedding:    [[vec_0_0], [vec_0_1], [vec_0_2], [vec_0_3], [vec_1_0], [vec_1_1] ...]
    #   passage_embedding2id: [id0, id0, id0, id0, id1, id1 ...]

    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if args.inference:
        return

    if is_first_worker():
        # Construct new training subset
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor
        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (
            effective_idx == (chunk_factor - 1)) else (
            q_start_idx + queries_per_chunk)
        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding[q_start_idx:q_end_idx]), num_queries))

        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            print('passage embedding shape: ' + str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            cpu_index = faiss.IndexFlatIP(dim)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            # for measuring ANN mrr
            logger.info("search dev query")
            dev_D, dev_I = index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
            logger.info("finish")
            logger.info("search train query")
            D, I = index.search(query_embedding[q_start_idx:q_end_idx], top_k)  # I: [number of queries, topk]
            logger.info("finish")
            index.reset()
        else:
            # (trec2019_D/trec2019_I are only produced in this split_ann_search path)
            if args.data_type == 0:
                trec2019_D, trec2019_I, _, _ = document_split_faiss_index(
                    logger=logger, args=args, checkpoint_step=checkpoint_step,
                    top_k_dev=200, top_k=args.topk_training,
                    dev_query_emb=trec2019_query_embedding,
                    train_query_emb=None,
                    emb_prefix="passage_", two_query_set=False,
                )
            dev_D, dev_I, D, I = document_split_faiss_index(
                logger=logger, args=args, checkpoint_step=checkpoint_step,
                top_k_dev=200, top_k=args.topk_training,
                dev_query_emb=dev_query_embedding,
                train_query_emb=query_embedding[q_start_idx:q_end_idx],
                emb_prefix="passage_")
            logger.info("***** separately process indexing *****")
        logger.info("***** Done ANN Index *****")

        logger.info("***** Begin evaluation *****")
        eval_dict_todump = {'checkpoint': checkpoint_path}
        if args.data_type == 0:
            trec2019_ndcg = evaluation(
                trec2019_query_embedding2id, passage_embedding2id,
                trec2019_I, trec2019_D,
                trec_prefix="dev_query_", test_set="trec2019")
        if args.dev_split_num > 0:
            marcodev_ndcg = 0.0
            for i in range(args.dev_split_num):
                ndcg_10_dev_split_i = evaluation(
                    dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
                    trec_prefix="real-dev_query_", test_set="marcodev", split_idx=i)
                if i != args.testing_split_idx:
                    marcodev_ndcg += ndcg_10_dev_split_i
                eval_dict_todump[f'marcodev_split_{i}_ndcg_cut_10'] = ndcg_10_dev_split_i
            logger.info(f"average marco dev {marcodev_ndcg / (args.dev_split_num - 1)}")
        else:
            marcodev_ndcg = evaluation(
                dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
                trec_prefix="real-dev_query_", test_set="marcodev", split_idx=-1)
        eval_dict_todump['marcodev_ndcg'] = marcodev_ndcg

        query_range_number = I.shape[0]
        if args.save_training_query_trec:
            logger.info("***** Save the ANN searching for negative passages in trec file format *****")
            trec_output_path = os.path.join(
                args.output_dir,
                "ann_training_query_retrieval_" + str(output_num) + ".trec")
            save_trec_file(
                query_embedding2id[q_start_idx:q_end_idx],
                passage_embedding2id, I, D, trec_output_path,
                topN=args.topk_training)

        effective_q_id = set(query_embedding2id[q_start_idx:q_end_idx].flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding2id[q_start_idx:q_end_idx],
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id)
        logger.info("***** Done ANN searching for negative passages *****")

        if args.d2q_task_evaluation and args.d2q_task_marco_dev_qrels is not None:
            with open(os.path.join(args.data_dir, 'pid2offset.pickle'), 'rb') as f:
                pid2offset = pickle.load(f)
            real_dev_ANCE_ids = []
            with open(args.d2q_task_marco_dev_qrels + f"{args.dev_split_num}_fold.split_dict", "rb") as f:
                dev_d2q_split_dict = pickle.load(f)
            for i in dev_d2q_split_dict:
                for stringdocid in dev_d2q_split_dict[i]:
                    if args.data_type == 0:
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid[1:])])
                    else:
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid)])
            real_dev_ANCE_ids = np.array(real_dev_ANCE_ids).flatten()
            real_dev_possitive_training_passage_id_embidx = []
            for dev_pos_pid in real_dev_ANCE_ids:
                embidx = np.asarray(np.where(passage_embedding2id == dev_pos_pid)).flatten()
                real_dev_possitive_training_passage_id_embidx.append(embidx)
            real_dev_possitive_training_passage_id_embidx = np.concatenate(
                real_dev_possitive_training_passage_id_embidx, axis=0)
            del pid2offset
            if not args.split_ann_search:
                real_dev_positive_p_embs = passage_embedding[real_dev_possitive_training_passage_id_embidx]
            else:
                real_dev_positive_p_embs = loading_possitive_document_embedding(
                    logger, args.output_dir, checkpoint_step,
                    real_dev_possitive_training_passage_id_embidx,
                    emb_prefix="passage_")
            logger.info("***** d2q task evaluation *****")
            cpu_index = faiss.IndexFlatIP(dev_query_embedding.shape[1])
            index = cpu_index
            index.add(dev_query_embedding)
            real_dev_d2q_D, real_dev_d2q_I = index.search(real_dev_positive_p_embs, 200)
            if args.dev_split_num > 0:
                d2q_marcodev_ndcg = 0.0
                for i in range(args.dev_split_num):
                    d2q_ndcg_10_dev_split_i = evaluation(
                        real_dev_ANCE_ids, dev_query_embedding2id,
                        real_dev_d2q_I, real_dev_d2q_D,
                        trec_prefix="d2q-dual-task_real-dev_query_",
                        test_set="marcodev", split_idx=i,
                        d2q_eval=True, d2q_qrels=args.d2q_task_marco_dev_qrels)
                    if i != args.testing_split_idx:
                        d2q_marcodev_ndcg += d2q_ndcg_10_dev_split_i
                    eval_dict_todump[f'd2q_marcodev_split_{i}_ndcg_cut_10'] = d2q_ndcg_10_dev_split_i
                logger.info(f"average marco dev d2q task {d2q_marcodev_ndcg / (args.dev_split_num - 1)}")
            else:
                d2q_marcodev_ndcg = evaluation(
                    real_dev_ANCE_ids, dev_query_embedding2id,
                    real_dev_d2q_I, real_dev_d2q_D,
                    trec_prefix="d2q-dual-task_real-dev_query_",
                    test_set="marcodev", split_idx=-1,
                    d2q_eval=True, d2q_qrels=args.d2q_task_marco_dev_qrels)
            eval_dict_todump['d2q_marcodev_ndcg'] = d2q_marcodev_ndcg

        if args.dual_training:
            # do this before completely truncating the query embedding
            logger.info("***** Do ANN Index for dual d2q task *****")
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            cpu_index = faiss.IndexFlatIP(query_embedding.shape[1])
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(query_embedding)
            logger.info("***** Done building ANN Index for dual d2q task *****")

            logger.info("***** use ANCE id to construct positive passage embedding index *****")
            # invert the q->d+ mapping: doc_id -> [query_id, ...]
            training_query_positive_id_inversed = {}
            for k in training_query_positive_id:
                pos_pid = training_query_positive_id[k]
                if pos_pid not in training_query_positive_id_inversed:
                    training_query_positive_id_inversed[pos_pid] = [k]
                else:
                    training_query_positive_id_inversed[pos_pid].append(k)

            possitive_training_passage_id = [
                training_query_positive_id[t]
                for t in query_embedding2id[q_start_idx:q_end_idx]]
            effective_p_id = set(possitive_training_passage_id)

            if "BM25_retrieval" == args.query_likelihood_strategy:
                passage_negative_queries = {}
                logger.info("***** loading negative queries from BM25 search result *****")
                with open(args.bm25_top_d2q_path, "r") as f:
                    for line in f:
                        pid, qid, rank = line.strip().split("\t")
                        pid = int(pid)
                        qid = int(qid)
                        if (pid in effective_p_id) and (qid not in training_query_positive_id_inversed[pid]):
                            if pid not in passage_negative_queries:
                                passage_negative_queries[pid] = [qid]
                            elif len(passage_negative_queries[pid]) < args.topk_training_d2q:
                                passage_negative_queries[pid].append(qid)
                logger.info(f"***** shuffle and pick {args.negative_sample} negative queries *****")
                for pid in passage_negative_queries:
                    random.shuffle(passage_negative_queries[pid])
                    passage_negative_queries[pid] = passage_negative_queries[pid][:args.negative_sample]
            else:
                # compatible with MaxP
                possitive_training_passage_id_embidx = []
                possitive_training_passage_id_to_subset_embidx = {}  # pid to indexes in pos_pas_embs
                possitive_training_passage_id_emb_counts = 0
                for pos_pid in possitive_training_passage_id:
                    embidx = np.asarray(np.where(passage_embedding2id == pos_pid)).flatten()
                    possitive_training_passage_id_embidx.append(embidx)
                    possitive_training_passage_id_to_subset_embidx[int(pos_pid)] = np.asarray(
                        range(possitive_training_passage_id_emb_counts,
                              possitive_training_passage_id_emb_counts + embidx.shape[0]))
                    possitive_training_passage_id_emb_counts += embidx.shape[0]
                possitive_training_passage_id_embidx = np.concatenate(
                    possitive_training_passage_id_embidx, axis=0)

                if not args.split_ann_search:
                    D, I = index.search(
                        passage_embedding[possitive_training_passage_id_embidx],
                        args.topk_training_d2q)
                else:
                    if args.query_likelihood_strategy == "positive_doc":
                        positive_p_embs = loading_possitive_document_embedding(
                            logger, args.output_dir, checkpoint_step,
                            possitive_training_passage_id_embidx,
                            emb_prefix="passage_")
                        assert positive_p_embs.shape[0] == len(possitive_training_passage_id)
                        D, I = index.search(positive_p_embs, args.topk_training_d2q)
                        positive_p_embs = None
                        del positive_p_embs
                index.reset()
                logger.info("***** Finish ANN searching for dual d2q task, construct *****")
                passage_negative_queries = GenerateNegativeQueryID(
                    args,
                    possitive_training_passage_id,
                    query_embedding2id,
                    training_query_positive_id_inversed,
                    I,
                    effective_p_id,
                    pid2pos_pas_embs_idxs=possitive_training_passage_id_to_subset_embidx,
                    Scores_nearest_neighbor=D if "multi_chunk" in args.model_type else None)
                logger.info("***** Done ANN searching for negative queries *****")

        # truncate to the current chunk
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info("***** Construct ANN Triplet *****")
        prefix = "ann_grouped_training_data_" if args.grouping_ann_data > 0 else "ann_training_data_"
        train_data_output_path = os.path.join(args.output_dir, prefix + str(output_num))
        if args.grouping_ann_data > 0:
            with open(train_data_output_path, 'w') as f:
                query_range = list(range(query_range_number))
                random.shuffle(query_range)
                counting = 0
                pos_q_group = {}
                pos_d_group = {}
                neg_D_group = {}  # {0: [], 1: [], 2: [] ...}
                if args.dual_training:
                    neg_Q_group = {}
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    if query_id not in effective_q_id or query_id not in training_query_positive_id:
                        continue
                    pos_pid = training_query_positive_id[query_id]
                    pos_q_group[counting] = int(query_id)
                    pos_d_group[counting] = int(pos_pid)
                    neg_D_group[counting] = [int(neg_pid) for neg_pid in query_negative_passage[query_id]]
                    if args.dual_training:
                        neg_Q_group[counting] = [int(neg_qid) for neg_qid in passage_negative_queries[pos_pid]]
                    counting += 1
                    if counting >= args.grouping_ann_data:
                        jsonline_dict = {}
                        jsonline_dict["pos_q_group"] = pos_q_group
                        jsonline_dict["pos_d_group"] = pos_d_group
                        jsonline_dict["neg_D_group"] = neg_D_group
                        if args.dual_training:
                            jsonline_dict["neg_Q_group"] = neg_Q_group
                        f.write(f"{json.dumps(jsonline_dict)}\n")
                        counting = 0
                        pos_q_group = {}
                        pos_d_group = {}
                        neg_D_group = {}
                        if args.dual_training:
                            neg_Q_group = {}
        else:
            with open(train_data_output_path, 'w') as f:
                query_range = list(range(query_range_number))
                random.shuffle(query_range)
                # old version implementation
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    if query_id not in effective_q_id or query_id not in training_query_positive_id:
                        continue
                    pos_pid = training_query_positive_id[query_id]
                    if not args.dual_training:
                        f.write(
                            "{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))
                    else:
                        if pos_pid not in effective_p_id or pos_pid not in training_query_positive_id_inversed:
                            continue
                        f.write(
                            "{}\t{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id]),
                                ','.join(str(neg_qid) for neg_qid in passage_negative_queries[pos_pid])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        if args.data_type == 0:
            eval_dict_todump['trec2019_ndcg'] = trec2019_ndcg
        with open(ndcg_output_path, 'w') as f:
            json.dump(eval_dict_todump, f)

    return None  # dev_ndcg, num_queries_dev
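# Each line that generate_new_ann() writes in the grouped format is a JSON
# object of aligned dicts keyed by position within the group. A minimal reader
# sketch (the ids below are made up for illustration):
import json

line = ('{"pos_q_group": {"0": 1001}, "pos_d_group": {"0": 2001}, '
        '"neg_D_group": {"0": [2002, 2003]}, "neg_Q_group": {"0": [1002]}}')
group = json.loads(line)
for k in group["pos_q_group"]:
    print(group["pos_q_group"][k],   # query id
          group["pos_d_group"][k],   # its positive passage id
          group["neg_D_group"][k])   # its hard-negative passage ids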
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        latest_step_num,
        training_query_positive_id=None,
        dev_query_positive_id=None,
):
    config, tokenizer, model = load_model(args, checkpoint_path)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    if args.inference:
        return

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False)
    logger.info("***** Done passage inference *****")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(
            args, dev_query_embedding2id, passage_embedding2id,
            dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor
        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (
            effective_idx == (chunk_factor - 1)) else (
            q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]
        logger.info(
            "Chunked {} query from {}".format(len(query_embedding), num_queries))

        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args, query_embedding2id, passage_embedding2id,
            training_query_positive_id, I, effective_q_id)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, pos_pid,
                        ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
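# generate_new_ann() ends the round by dumping {"ndcg": ..., "checkpoint": ...}
# into ann_ndcg_<n>; the train() loops below poll for that file and read both
# keys (the dev ndcg to log, and the checkpoint the negatives came from).
# Round-trip sketch of that handshake (paths and values are made up):
import json
import os
import tempfile

out_dir = tempfile.mkdtemp()
with open(os.path.join(out_dir, "ann_ndcg_0"), "w") as f:
    json.dump({"ndcg": 0.42, "checkpoint": "checkpoint-1000"}, f)
with open(os.path.join(out_dir, "ann_ndcg_0")) as f:
    ndcg_json = json.load(f)
print(ndcg_json["ndcg"], ndcg_json["checkpoint"])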
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # layer-wise optimizer parameter groups
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out",
            "downsample1", "downsample2", "downsample3"]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append(
        {"params": [p for p in model.parameters() if p not in layer_optim_params]})

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) \
            and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        real_batch_size,
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        if "-" in args.model_name_or_path:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility
    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    save_no = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:
        if step % args.gradient_accumulation_steps == 0 \
                and global_step % args.logging_steps == 0 \
                and global_step % args.save_steps < args.save_steps / 20:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg, ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no, global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1
        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()}
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]}

        # sync gradients only at gradient accumulation step
        # (no_sync assumes the model is wrapped in DistributedDataParallel)
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(
                    model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0
                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)
                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)
                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s", output_dir)
                save_no += 1
                if save_no > 1:
                    ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                    while ann_no == last_ann_no:
                        print("Waiting for new ann_data. Sleeping for 1hr!!")
                        time.sleep(3600)
                        ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            dist.barrier()

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()
    return global_step
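# The accumulation pattern above forwards/backwards under model.no_sync() on
# non-boundary micro-steps so DDP skips its gradient all-reduce until the real
# update step. A minimal standalone sketch of that pattern (assumes
# torch.distributed is initialized and `model` is DDP-wrapped; names are
# illustrative):
def accumulate_step(model, optimizer, batches, accum_steps):
    for step, batch in enumerate(batches, start=1):
        if step % accum_steps == 0:
            loss = model(**batch)[0] / accum_steps
            loss.backward()           # gradients all-reduced here
            optimizer.step()
            optimizer.zero_grad()
        else:
            with model.no_sync():     # skip gradient sync on this micro-step
                loss = model(**batch)[0] / accum_steps
                loss.backward()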
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)  # nll loss for query
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer = get_optimizer(args, model, weight_decay=args.weight_decay)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d", args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        real_batch_size,
    )
    logger.info("  Gradient Accumulation steps = %d", args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility
    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    iter_count = 0

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler, saved_state)
        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from global step %d", global_step)
    # NOTE: step counting restarts from 0 even when model/optimizer states
    # were restored above
    global_step = 0

    while global_step < args.max_steps:
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            if args.num_epoch == 0:
                # check if new ann training data is available
                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    time.sleep(180)  # wait until transmission finished
                    with open(ann_path, 'r') as f:
                        ann_training_data = f.readlines()
                    logger.info("Training data line count: %d", len(ann_training_data))
                    # drop lines with fewer than two negatives
                    ann_training_data = [
                        l for l in ann_training_data
                        if len(l.split('\t')[2].split(',')) > 1]
                    logger.info("Filtered training data line count: %d", len(ann_training_data))

                    aligned_size = (len(ann_training_data) // args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]
                    logger.info("Total ann queries: %d", len(ann_training_data))

                    if args.triplet:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache))
                        train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTrainingDataProcessingFn(args, query_cache, passage_cache))
                        train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size * 2)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no, global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:
                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                if args.triplet:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTripletTrainingDataProcessingFn(args, query_cache, passage_cache))
                    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
                else:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTrainingDataProcessingFn(args, query_cache, passage_cache))
                    train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size * 2)
                all_batch = [b for b in train_dataloader]
                logger.info("Total batch count: %d", len(all_batch))
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1, global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                logger.info("Training data line count: %d", len(ann_training_data))
                ann_training_data = [
                    l for l in ann_training_data
                    if len(l.split('\t')[2].split(',')) > 1]
                logger.info("Filtered training data line count: %d", len(ann_training_data))
                aligned_size = (len(ann_training_data) // args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                train_dataset = StreamingDataset(
                    ann_training_data,
                    GetTrainingDataProcessingFn(args, query_cache, passage_cache))
                train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size * 2)
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)
            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0
                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                _save_checkpoint(args, model, optimizer, scheduler, global_step)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()
    return global_step
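# Every train() variant truncates the ANN file to a multiple of world_size so
# each rank consumes the same number of lines. Tiny demo of the truncation
# arithmetic (values are made up):
lines = list(range(10))          # 10 training lines
world_size = 4
aligned_size = (len(lines) // world_size) * world_size
print(lines[:aligned_size])      # 8 lines: [0..7]; the last 2 are dropped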
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-sec-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_sec_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-sec-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_sec_query_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False, load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        train_pos_id, test_pos_id = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training

        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info('Data indexing completed.')

        logger.info("Start searching for query embedding with length %d",
                    len(query_embedding))
        # search in 15 slices of 5000 queries each to bound peak memory
        II = list()
        for i in range(15):
            _, idx = cpu_index.search(
                query_embedding[i * 5000:(i + 1) * 5000], top_k)  # I: [number of queries, topk]
            II.append(idx)
            logger.info("Split done %d", i)
        I = np.concatenate(II, axis=0)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        query_negative_passage = dict()
        for query_idx in range(I.shape[0]):
            query_id = query_embedding2id[query_idx]
            doc_ids = [passage_embedding2id[pidx] for pidx in I[query_idx]]
            neg_docs = list()
            for doc_id in doc_ids:
                pos_id = [int(p_id) for p_id in train_pos_id[query_id].split(',')]
                if doc_id in pos_id:
                    continue
                if doc_id in neg_docs:
                    continue
                neg_docs.append(doc_id)
            query_negative_passage[query_id] = neg_docs
        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid,
                    ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        _, dev_I = cpu_index.search(dev_query_embedding, 10)  # I: [number of queries, topk]
        top_k_hits = validate(test_pos_id, dev_I, dev_query_embedding2id,
                              passage_embedding2id)
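# The hardcoded 15 x 5000 slicing above bakes in the train query count; a more
# general chunked-search sketch over a faiss inner-product index (all names
# and sizes below are illustrative, not the repository's):
import faiss
import numpy as np

dim, top_k, chunk = 8, 3, 2
index = faiss.IndexFlatIP(dim)
index.add(np.random.rand(100, dim).astype('float32'))   # toy passages
queries = np.random.rand(5, dim).astype('float32')      # toy queries
parts = [index.search(queries[s:s + chunk], top_k)[1]   # [1] keeps the ids
         for s in range(0, len(queries), chunk)]
I = np.concatenate(parts, axis=0)                       # [num_queries, top_k]
print(I.shape)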
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    # layer-wise optimizer parameter groups
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out",
            "downsample1", "downsample2", "downsample3"]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append(
        {"params": [p for p in model.parameters() if p not in layer_optim_params]})

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    def optimizer_to(optim, device):
        # move optimizer state tensors onto the target device
        for param in optim.state.values():
            # Not sure there are any global tensors in the state dict
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(device)
        torch.cuda.empty_cache()

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) \
            and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"),
                       map_location='cpu'))
        optimizer_to(optimizer, args.device)

    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training.")
        model, optimizer = amp.initialize(
            model, optimizer, opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        real_batch_size,
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to global_step of last saved checkpoint from model path
        if "-" in args.model_name_or_path:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info("  Continuing training from checkpoint, will skip to saved global_step")
        logger.info("  Continuing training from global step %d", global_step)

    is_hypersphere_training = (args.hyper_align_weight > 0 or args.hyper_unif_weight > 0)
    if is_hypersphere_training:
        logger.info(
            f"training with hypersphere property regularization, "
            f"align weight {args.hyper_align_weight}, unif weight {args.hyper_unif_weight}")

    if not args.dual_training:
        args.dual_loss_weight = 0.0

    tr_loss_dict = {}
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility
    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)
        if os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) \
                and args.load_optimizer_scheduler:
            # Load in optimizer and scheduler states
            scheduler.load_state_dict(
                torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    while global_step < args.max_steps:
        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(
                args.ann_dir, is_grouped=(args.grouping_ann_data > 0))
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                time.sleep(30)  # wait until transmission finished
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                logging.info(f"loading:\n{ndcg_json}")
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) // args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                logger.info(
                    "Total ann queries: %d",
                    len(ann_training_data) if args.grouping_ann_data < 0
                    else len(ann_training_data) * args.grouping_ann_data)

                if args.grouping_ann_data > 0:
                    if args.polling_loaded_data_batch_from_group:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_polling(
                                args, query_cache, passage_cache))
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_origin(
                                args, query_cache, passage_cache))
                else:
                    if not args.dual_training:
                        if args.triplet:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTripletTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                        else:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                    else:
                        # return quadruplet
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetQuadrapuletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                train_dataloader = DataLoader(train_dataset, batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data)
                        if args.grouping_ann_data < 0
                        else len(ann_training_data) * args.grouping_ann_data)

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    for key in ndcg_json:
                        if "marcodev" in key:
                            tb_writer.add_scalar(key, ndcg_json[key], ann_checkpoint_no)
                    if 'trec2019_ndcg' in ndcg_json:
                        tb_writer.add_scalar("trec2019_ndcg",
                                             ndcg_json['trec2019_ndcg'],
                                             ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no, global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        # original way
        if args.grouping_ann_data <= 0:
            batch = tuple(t.to(args.device) for t in batch)
            if args.triplet:
                inputs = {
                    "query_ids": batch[0].long(),
                    "attention_mask_q": batch[1].long(),
                    "input_ids_a": batch[3].long(),
                    "attention_mask_a": batch[4].long(),
                    "input_ids_b": batch[6].long(),
                    "attention_mask_b": batch[7].long()}
                if args.dual_training:
                    inputs["neg_query_ids"] = batch[9].long()
                    inputs["attention_mask_neg_query"] = batch[10].long()
                    inputs["prime_loss_weight"] = args.prime_loss_weight
                    inputs["dual_loss_weight"] = args.dual_loss_weight
            else:
                inputs = {
                    "input_ids_a": batch[0].long(),
                    "attention_mask_a": batch[1].long(),
                    "input_ids_b": batch[3].long(),
                    "attention_mask_b": batch[4].long(),
                    "labels": batch[6]}
        else:
            # the default collate_fn converts item["q_pos"] etc. into batch format
            inputs = {
                "query_ids": batch["q_pos"][0].to(args.device).long(),
                "attention_mask_q": batch["q_pos"][1].to(args.device).long(),
                "input_ids_a": batch["d_pos"][0].to(args.device).long(),
                "attention_mask_a": batch["d_pos"][1].to(args.device).long(),
                "input_ids_b": batch["d_neg"][0].to(args.device).long(),
                "attention_mask_b": batch["d_neg"][1].to(args.device).long(),
            }
            if args.dual_training:
                inputs["neg_query_ids"] = batch["q_neg"][0].to(args.device).long()
                inputs["attention_mask_neg_query"] = batch["q_neg"][1].to(args.device).long()
                inputs["prime_loss_weight"] = args.prime_loss_weight
                inputs["dual_loss_weight"] = args.dual_loss_weight

        inputs["temperature"] = args.temperature
        inputs["loss_objective"] = args.loss_objective_function
        if is_hypersphere_training:
            inputs["alignment_weight"] = args.hyper_align_weight
            inputs["uniformity_weight"] = args.hyper_unif_weight

        step += 1
        if args.local_rank != -1:
            # sync gradients only at gradient accumulation step
            if step % args.gradient_accumulation_steps == 0:
                outputs = model(**inputs)
            else:
                with model.no_sync():
                    outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]
        loss_item_dict = outputs[1]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k].mean()
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k] / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if args.local_rank != -1:
                if step % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        # assumed completion: the source is cut off here; the
                        # backward mirrors the sync/no_sync forward pass above
                        loss.backward()
            else:
                loss.backward()
loss.backward() else: loss.backward() def incremental_tr_loss(tr_loss_dict, loss_item_dict, total_loss): for k in loss_item_dict: if k not in tr_loss_dict: tr_loss_dict[k] = loss_item_dict[k].item() else: tr_loss_dict[k] += loss_item_dict[k].item() if "loss_total" not in tr_loss_dict: tr_loss_dict["loss_total"] = total_loss.item() else: tr_loss_dict["loss_total"] += total_loss.item() return tr_loss_dict tr_loss_dict = incremental_tr_loss(tr_loss_dict, loss_item_dict, total_loss=loss) if step % args.gradient_accumulation_steps == 0: if args.fp16: torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer), args.max_grad_norm) else: torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm) optimizer.step() scheduler.step() # Update learning rate schedule model.zero_grad() global_step += 1 if args.logging_steps > 0 and global_step % args.logging_steps == 0: logs = {} learning_rate_scalar = scheduler.get_lr()[0] logs["learning_rate"] = learning_rate_scalar for k in tr_loss_dict: logs[k] = tr_loss_dict[k] / args.logging_steps tr_loss_dict = {} if is_first_worker(): for key, value in logs.items(): tb_writer.add_scalar(key, value, global_step) logger.info(json.dumps({**logs, **{"step": global_step}})) if is_first_worker( ) and args.save_steps > 0 and global_step % args.save_steps == 0: # Save model checkpoint output_dir = os.path.join(args.output_dir, "checkpoint-{}".format(global_step)) if not os.path.exists(output_dir): os.makedirs(output_dir) model_to_save = ( model.module if hasattr(model, "module") else model ) # Take care of distributed/parallel training model_to_save.save_pretrained(output_dir) tokenizer.save_pretrained(output_dir) torch.save(args, os.path.join(output_dir, "training_args.bin")) logger.info("Saving model checkpoint to %s", output_dir) torch.save(optimizer.state_dict(), os.path.join(output_dir, "optimizer.pt")) torch.save(scheduler.state_dict(), os.path.join(output_dir, "scheduler.pt")) logger.info("Saving optimizer and scheduler states to %s", output_dir) if args.local_rank == -1 or torch.distributed.get_rank() == 0: tb_writer.close() return global_step
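# The loop above skips DDP's gradient all-reduce on non-accumulation steps via
# model.no_sync(). A minimal standalone sketch of that pattern, with a toy model
# and random data (illustrative only, not the trainer above):
import contextlib

import torch


def accumulate_gradients(model, micro_batches, accumulation_steps):
    """Backward through each micro-batch; sync gradients only on the last one."""
    for micro_step, batch in enumerate(micro_batches, start=1):
        will_sync = (micro_step % accumulation_steps == 0)
        # plain nn.Modules have no no_sync(); fall back to a null context
        ctx = model.no_sync() if (not will_sync and hasattr(model, "no_sync")) \
            else contextlib.nullcontext()
        with ctx:
            loss = model(batch).sum() / accumulation_steps
            loss.backward()


if __name__ == "__main__":
    toy = torch.nn.Linear(4, 1)
    accumulate_gradients(toy, [torch.randn(2, 4) for _ in range(4)], 2)
    print(toy.weight.grad.shape)  # gradients accumulated over all micro-batches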
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs
    logger.info("total optimization steps = %d", t_total)

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append({
        "params": [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".format(
                args.scheduler))

    # Check if saved optimizer and scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) \
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) \
            and args.load_optimizer_scheduler:
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    tr_acc, logging_acc = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            if not args.reset_iter:
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            acc = outputs[1]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                acc = acc.float().mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                acc = acc / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            tr_acc += acc.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                                   args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)
                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if is_first_worker() and args.save_steps > 0 and global_step % args.save_steps == 0:
                    # Save model checkpoint
                    output_dir = os.path.join(args.output_dir,
                                              "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    if 'fairseq' not in args.train_model_type:
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                    else:
                        torch.save(model.state_dict(),
                                   os.path.join(output_dir, 'model.pt'))
                    torch.save(args, os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)
                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s", output_dir)
                dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval * args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(
                            args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {
                                "reranking": float(reranking_mrr),
                                "full_ranking": float(full_ranking_mrr)
                            }
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss
                    acc_scalar = (tr_acc - logging_acc) / args.logging_steps
                    logs["acc"] = acc_scalar
                    logging_acc = tr_acc

                    if is_first_worker():
                        for key, value in logs.items():
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        print(json.dumps({**logs, **{"step": global_step}}))
                    dist.barrier()

            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()
    return global_step, tr_loss / global_step
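# Both train() variants above recover global_step by parsing the trailing number
# out of a "checkpoint-<step>" directory name. A minimal sketch of that
# convention (the paths below are hypothetical examples):
def parse_global_step(model_name_or_path: str) -> int:
    """Return the global step encoded in a checkpoint path, or 0 if absent."""
    if "-" not in model_name_or_path:
        return 0
    tail = model_name_or_path.split("-")[-1].split("/")[0]
    return int(tail) if tail.isdigit() else 0


assert parse_global_step("out/checkpoint-15000") == 15000
assert parse_global_step("out/checkpoint-15000/") == 15000
assert parse_global_step("roberta-base") == 0  # pretrained model, no step suffix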
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data,
                     latest_step_num):
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')

    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="train-query", emb_prefix="query_", is_query_inference=True)
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="test-query", emb_prefix="dev_query_", is_query_inference=True)
    dev_query_embedding_trivia, dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-test-query", emb_prefix="trivia_dev_query_",
        is_query_inference=True)
    real_dev_query_embedding, real_dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="dev-qas-query", emb_prefix="real-dev_query_",
        is_query_inference=True)
    real_dev_query_embedding_trivia, real_dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-dev-qas-query", emb_prefix="trivia_real-dev_query_",
        is_query_inference=True)
    # passage_embedding is None if args.split_ann_search == True
    passage_embedding, passage_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="passages", emb_prefix="passage_",
        is_query_inference=False, load_emb=not args.split_ann_search)

    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)
    if args.local_rank != -1:
        dist.barrier()

    # if None, reload the passage id mapping only
    if passage_embedding2id is None and is_first_worker():
        _, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=None, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_",
            is_query_inference=False, load_emb=False)
        logger.info(f"document id size: {passage_embedding2id.shape}")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, \
            test_answers_trivia, dev_answers, dev_answers_trivia = preloaded_data

        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            logger.info('passage embedding shape: %s', str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(16)
            cpu_index = faiss.IndexFlatIP(dim)
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            logger.info("***** Done ANN Index *****")
            # I: [number of queries, topk]
            _, dev_I = index.search(dev_query_embedding, 100)
            _, dev_I_trivia = index.search(dev_query_embedding_trivia, 100)
            # the real-dev validation below also needs these results
            _, real_dev_I = index.search(real_dev_query_embedding, 100)
            logger.info("Start searching for query embedding with length %d",
                        len(query_embedding))
            _, I = index.search(query_embedding, top_k)
        else:
            _, dev_I_trivia, real_dev_D, real_dev_I = document_split_faiss_index(
                logger=logger,
                args=args,
                top_k_dev=100,
                top_k=args.topk_training,
                checkpoint_step=checkpoint_step,
                dev_query_emb=dev_query_embedding_trivia,
                train_query_emb=real_dev_query_embedding,
                emb_prefix="passage_",
                two_query_set=True,
            )
            dev_D, dev_I, _, I = document_split_faiss_index(
                logger=logger,
                args=args,
                top_k_dev=100,
                top_k=args.topk_training,
                checkpoint_step=checkpoint_step,
                dev_query_emb=dev_query_embedding,
                train_query_emb=query_embedding,
                emb_prefix="passage_",
                two_query_set=True,
            )
            save_trec_file(
                dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
                trec_save_path=os.path.join(args.training_dir, "ann_data",
                                            "nq-test_" + checkpoint_step + ".trec"),
                topN=100)
            save_trec_file(
                real_dev_query_embedding2id, passage_embedding2id, real_dev_I, real_dev_D,
                trec_save_path=os.path.join(args.training_dir, "ann_data",
                                            "nq-dev_" + checkpoint_step + ".trec"),
                topN=100)

        # measure ANN top-k hit rates
        top_k_hits = validate(passage_text, test_answers, dev_I,
                              dev_query_embedding2id, passage_embedding2id)
        real_dev_top_k_hits = validate(passage_text, dev_answers, real_dev_I,
                                       real_dev_query_embedding2id, passage_embedding2id)
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I_trivia,
                                     dev_query_embedding2id_trivia, passage_embedding2id)
        query_range_number = I.shape[0]
        json_dump_dict = {
            'top20': top_k_hits[19],
            'top100': top_k_hits[99],
            'top20_trivia': top_k_hits_trivia[19],
            'dev_top20': real_dev_top_k_hits[19],
            'dev_top100': real_dev_top_k_hits[99],
            'top100_trivia': top_k_hits_trivia[99],
            'checkpoint': checkpoint_path,
            'n_train_query': query_range_number,
        }
        logger.info(json_dump_dict)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        logger.info("Effective qid length %d, search result length %d",
                    len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(
            args, passage_text, train_answers, query_embedding2id,
            passage_embedding2id, I, train_pos_id)
        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        if args.dual_training:
            assert args.split_ann_search and args.gpu_index  # hard requirement
            logger.info("***** Begin ANN Index for dual d2q task *****")
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            cpu_index = faiss.IndexFlatIP(query_embedding.shape[1])
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(query_embedding)
            logger.info("***** Done building ANN Index for dual d2q task *****")

            # train_pos_id: a list, idx -> int pid; invert it to pid -> [qid, ...]
            train_pos_id_inversed = {}
            for qidx in range(query_embedding2id.shape[0]):
                qid = query_embedding2id[qidx]
                pid = int(train_pos_id[qid])
                if pid not in train_pos_id_inversed:
                    train_pos_id_inversed[pid] = [qid]
                else:
                    train_pos_id_inversed[pid].append(qid)

            possitive_training_passage_id = [train_pos_id[t] for t in query_embedding2id]
            # compatible with MaxP: one pid can map to several embedding rows
            possitive_training_passage_id_embidx = []
            possitive_training_passage_id_to_subset_embidx = {}  # pid -> indexes in pos_pas_embs
            possitive_training_passage_id_emb_counts = 0
            for pos_pid in possitive_training_passage_id:
                embidx = np.asarray(np.where(passage_embedding2id == pos_pid)).flatten()
                possitive_training_passage_id_embidx.append(embidx)
                possitive_training_passage_id_to_subset_embidx[int(pos_pid)] = np.asarray(
                    range(possitive_training_passage_id_emb_counts,
                          possitive_training_passage_id_emb_counts + embidx.shape[0]))
                possitive_training_passage_id_emb_counts += embidx.shape[0]
            possitive_training_passage_id_embidx = np.concatenate(
                possitive_training_passage_id_embidx, axis=0)

            if not args.split_ann_search:
                D, I = index.search(
                    passage_embedding[possitive_training_passage_id_embidx],
                    args.topk_training_d2q)
            else:
                positive_p_embs = loading_possitive_document_embedding(
                    logger,
                    args.output_dir,
                    checkpoint_step,
                    possitive_training_passage_id_embidx,
                    emb_prefix="passage_",
                )
                assert positive_p_embs.shape[0] == len(possitive_training_passage_id)
                D, I = index.search(positive_p_embs, args.topk_training_d2q)
                positive_p_embs = None
                del positive_p_embs

            index.reset()
            logger.info("***** Finish ANN searching for dual d2q task, construct *****")
            passage_negative_queries = GenerateNegativeQueryID(
                args,
                passage_text,
                train_answers,
                query_embedding2id,
                passage_embedding2id[possitive_training_passage_id_embidx],
                closest_ans=I,
                training_query_positive_id_inversed=train_pos_id_inversed)
            logger.info("***** Done ANN searching for negative queries *****")

        logger.info("***** Construct ANN Triplet *****")
        prefix = "ann_grouped_training_data_" if args.grouping_ann_data > 0 \
            else "ann_training_data_"
        train_data_output_path = os.path.join(args.output_dir, prefix + str(output_num))
        query_range = list(range(query_range_number))
        random.shuffle(query_range)

        if args.grouping_ann_data > 0:
            with open(train_data_output_path, 'w') as f:
                counting = 0
                pos_q_group = {}
                pos_d_group = {}
                neg_D_group = {}  # {0: [], 1: [], 2: [] ...}
                if args.dual_training:
                    neg_Q_group = {}
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    pos_pid = train_pos_id[query_id]
                    pos_q_group[counting] = int(query_id)
                    pos_d_group[counting] = int(pos_pid)
                    neg_D_group[counting] = [
                        int(neg_pid) for neg_pid in query_negative_passage[query_id]
                    ]
                    if args.dual_training:
                        neg_Q_group[counting] = [
                            int(neg_qid) for neg_qid in passage_negative_queries[pos_pid]
                        ]
                    counting += 1
                    if counting >= args.grouping_ann_data:
                        jsonline_dict = {}
                        jsonline_dict["pos_q_group"] = pos_q_group
                        jsonline_dict["pos_d_group"] = pos_d_group
                        jsonline_dict["neg_D_group"] = neg_D_group
                        if args.dual_training:
                            jsonline_dict["neg_Q_group"] = neg_Q_group
                        f.write(f"{json.dumps(jsonline_dict)}\n")
                        counting = 0
                        pos_q_group = {}
                        pos_d_group = {}
                        neg_D_group = {}
                        if args.dual_training:
                            neg_Q_group = {}
        else:
            # flat tab-separated triplet / quadruplet output
            with open(train_data_output_path, 'w') as f:
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    pos_pid = train_pos_id[query_id]
                    if not args.dual_training:
                        f.write("{}\t{}\t{}\n".format(
                            query_id, pos_pid,
                            ','.join(str(neg_pid)
                                     for neg_pid in query_negative_passage[query_id])))
                    else:
                        f.write("{}\t{}\t{}\t{}\n".format(
                            query_id, pos_pid,
                            ','.join(str(neg_pid)
                                     for neg_pid in query_negative_passage[query_id]),
                            ','.join(str(neg_qid)
                                     for neg_qid in passage_negative_queries[pos_pid])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump(json_dump_dict, f)
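# generate_new_ann() above emits either grouped jsonline records or flat
# tab-separated triplet/quadruplet lines. A minimal sketch of readers for both
# layouts (the sample strings are illustrative, not real ids):
import json


def parse_triplet_line(line):
    """'qid<TAB>pos_pid<TAB>neg,neg,...' -> (int, int, [int, ...])"""
    qid, pos_pid, neg_pids = line.rstrip("\n").split("\t")[:3]
    return int(qid), int(pos_pid), [int(p) for p in neg_pids.split(",")]


def parse_grouped_line(line):
    """One json object holding up to grouping_ann_data (q, d+, D-) groups."""
    record = json.loads(line)
    return record["pos_q_group"], record["pos_d_group"], record["neg_D_group"]


print(parse_triplet_line("11\t901\t77,78,79"))
print(parse_grouped_line(
    '{"pos_q_group": {"0": 11}, "pos_d_group": {"0": 901}, "neg_D_group": {"0": [77, 78]}}'))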
def generate_new_ann(args, checkpoint_path):
    if args.gpu_index:
        clean_faiss_gpu()
    if not args.not_load_model_for_inference:
        config, tokenizer, model = load_model(args, checkpoint_path)

    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')

    def evaluation(dev_query_embedding2id,
                   passage_embedding2id,
                   dev_I,
                   dev_D,
                   trec_prefix="real-dev_query_",
                   test_set="trec2019",
                   split_idx=-1,
                   d2q_eval=False,
                   d2q_qrels=None):
        if d2q_eval:
            qrels = d2q_qrels
        else:
            if args.data_type == 0:
                if test_set == "marcodev":
                    qrels = "../data/raw_data/msmarco-docdev-qrels.tsv"
                elif test_set == "trec2019":
                    qrels = "../data/raw_data/2019qrels-docs.txt"
            elif args.data_type == 1:
                if test_set == "marcodev":
                    qrels = "../data/raw_data/qrels.dev.small.tsv"
            else:
                logging.error("wrong data type")
                exit()
        trec_path = os.path.join(args.output_dir,
                                 trec_prefix + str(checkpoint_step) + ".trec")
        save_trec_file(dev_query_embedding2id,
                       passage_embedding2id,
                       dev_I,
                       dev_D,
                       trec_save_path=trec_path,
                       topN=200)
        convert_trec_to_MARCO_id(data_type=args.data_type,
                                 test_set=test_set,
                                 processed_data_dir=args.data_dir,
                                 trec_path=trec_path,
                                 d2q_reversed_trec_file=d2q_eval)
        trec_path = trec_path.replace(".trec", ".formatted.trec")
        met = Metric()
        if split_idx >= 0:
            split_file_path = qrels + f"{args.dev_split_num}_fold.split_dict"
            with open(split_file_path, 'rb') as f:
                split = pickle.load(f)
        else:
            split = None
        ndcg10 = met.get_metric(qrels, trec_path, 'ndcg_cut_10', split, split_idx)
        mrr10 = met.get_mrr(qrels, trec_path, 'mrr_cut_10', split, split_idx)
        mrr100 = met.get_mrr(qrels, trec_path, 'mrr_cut_100', split, split_idx)
        logging.info(
            f" evaluation for {test_set}, trec_file {trec_path}, split_idx {split_idx} "
            f"ndcg_cut_10 : {ndcg10}, mrr_cut_10 : {mrr10}, mrr_cut_100 : {mrr100}")
        return ndcg10

    # Inference
    if args.data_type == 0:
        # TREC DL 2019 eval set; the "dev-query" file is actually the trec-dl test set
        trec2019_query_embedding, trec2019_query_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
            text_data_prefix="dev-query", emb_prefix="dev_query_",
            is_query_inference=True)
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="real-dev-query", emb_prefix="real-dev_query_",
        is_query_inference=True)
    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="train-query", emb_prefix="query_", is_query_inference=True)
    if not args.split_ann_search:
        # merge all passages into one embedding matrix
        passage_embedding, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_",
            is_query_inference=False)
    else:
        # keep the ids only; embeddings stay sharded on disk
        _, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_",
            is_query_inference=False, load_emb=False)
    # FirstP shape:
    #   passage_embedding:    [[vec_0], [vec_1], [vec_2], [vec_3] ...]
    #   passage_embedding2id: [id0, id1, id2, id3, ...]
    # MaxP shape:
    #   passage_embedding:    [[vec_0_0], [vec_0_1], [vec_0_2], [vec_0_3], [vec_1_0], [vec_1_1] ...]
    #   passage_embedding2id: [id0, id0, id0, id0, id1, id1 ...]

    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if not is_first_worker():
        return
    else:
        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            logger.info('passage embedding shape: %s', str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            cpu_index = faiss.IndexFlatIP(dim)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            # for measuring ANN mrr
            logger.info("search dev query")
            dev_D, dev_I = index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
            logger.info("finish")
            logger.info("search train query")
            D, I = index.search(query_embedding, top_k)  # I: [number of queries, topk]
            logger.info("finish")
            if args.data_type == 0:
                # the trec2019 evaluation below reads these results in both index paths
                logger.info("search trec2019 query")
                trec2019_D, trec2019_I = index.search(trec2019_query_embedding, 200)
                logger.info("finish")
            index.reset()
        else:
            if args.data_type == 0:
                trec2019_D, trec2019_I, _, _ = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    checkpoint_step=checkpoint_step,
                    top_k_dev=200,
                    top_k=args.topk_training,
                    dev_query_emb=trec2019_query_embedding,
                    train_query_emb=None,
                    emb_prefix="passage_",
                    two_query_set=False,
                )
            dev_D, dev_I, D, I = document_split_faiss_index(
                logger=logger,
                args=args,
                checkpoint_step=checkpoint_step,
                top_k_dev=200,
                top_k=args.topk_training,
                dev_query_emb=dev_query_embedding,
                train_query_emb=query_embedding,
                emb_prefix="passage_")
            logger.info("***** separately processed indexing *****")

        logger.info("***** Done ANN Index *****")

        logger.info("***** Begin evaluation *****")
        eval_dict_todump = {'checkpoint': checkpoint_path}
        if args.data_type == 0:
            trec2019_ndcg = evaluation(trec2019_query_embedding2id,
                                       passage_embedding2id,
                                       trec2019_I,
                                       trec2019_D,
                                       trec_prefix="dev_query_",
                                       test_set="trec2019")

        if args.dev_split_num > 0:
            marcodev_ndcg = 0.0
            for i in range(args.dev_split_num):
                ndcg_10_dev_split_i = evaluation(dev_query_embedding2id,
                                                 passage_embedding2id,
                                                 dev_I,
                                                 dev_D,
                                                 trec_prefix="real-dev_query_",
                                                 test_set="marcodev",
                                                 split_idx=i)
                if i != args.testing_split_idx:
                    marcodev_ndcg += ndcg_10_dev_split_i
                eval_dict_todump[f'marcodev_split_{i}_ndcg_cut_10'] = ndcg_10_dev_split_i
            logger.info(f"average marco dev {marcodev_ndcg / (args.dev_split_num - 1)}")
        else:
            marcodev_ndcg = evaluation(dev_query_embedding2id,
                                       passage_embedding2id,
                                       dev_I,
                                       dev_D,
                                       trec_prefix="real-dev_query_",
                                       test_set="marcodev",
                                       split_idx=-1)
        eval_dict_todump['marcodev_ndcg'] = marcodev_ndcg

        if args.save_training_query_trec:
            logger.info(
                "***** Save the ANN search for negative passages in trec file format *****")
            trec_output_path = os.path.join(
                args.output_dir,
                "ann_training_query_retrieval_" + str(checkpoint_step) + ".trec")
            save_trec_file(query_embedding2id,
                           passage_embedding2id,
                           I,
                           D,
                           trec_output_path,
                           topN=args.topk_training)
            convert_trec_to_MARCO_id(data_type=args.data_type,
                                     test_set="training",
                                     processed_data_dir=args.data_dir,
                                     trec_path=trec_output_path,
                                     d2q_reversed_trec_file=False)
            logger.info("***** Done ANN searching for negative passages *****")

        if args.d2q_task_evaluation and args.d2q_task_marco_dev_qrels is not None:
            with open(os.path.join(args.data_dir, 'pid2offset.pickle'), 'rb') as f:
                pid2offset = pickle.load(f)
            real_dev_ANCE_ids = []
            with open(args.d2q_task_marco_dev_qrels +
                      f"{args.dev_split_num}_fold.split_dict", "rb") as f:
                dev_d2q_split_dict = pickle.load(f)
            for i in dev_d2q_split_dict:
                for stringdocid in dev_d2q_split_dict[i]:
                    if args.data_type == 0:
                        # document ids look like "D123456": strip the leading "D"
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid[1:])])
                    else:
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid)])
            real_dev_ANCE_ids = np.array(real_dev_ANCE_ids).flatten()

            real_dev_possitive_training_passage_id_embidx = []
            for dev_pos_pid in real_dev_ANCE_ids:
                embidx = np.asarray(np.where(passage_embedding2id == dev_pos_pid)).flatten()
                real_dev_possitive_training_passage_id_embidx.append(embidx)
            real_dev_possitive_training_passage_id_embidx = np.concatenate(
                real_dev_possitive_training_passage_id_embidx, axis=0)
            del pid2offset

            if not args.split_ann_search:
                real_dev_positive_p_embs = passage_embedding[
                    real_dev_possitive_training_passage_id_embidx]
            else:
                real_dev_positive_p_embs = loading_possitive_document_embedding(
                    logger,
                    args.output_dir,
                    checkpoint_step,
                    real_dev_possitive_training_passage_id_embidx,
                    emb_prefix="passage_",
                )

            logger.info("***** d2q task evaluation *****")
            cpu_index = faiss.IndexFlatIP(dev_query_embedding.shape[1])
            index = cpu_index
            index.add(dev_query_embedding)
            real_dev_d2q_D, real_dev_d2q_I = index.search(real_dev_positive_p_embs, 200)

            if args.dev_split_num > 0:
                d2q_marcodev_ndcg = 0.0
                for i in range(args.dev_split_num):
                    d2q_ndcg_10_dev_split_i = evaluation(
                        real_dev_ANCE_ids,
                        dev_query_embedding2id,
                        real_dev_d2q_I,
                        real_dev_d2q_D,
                        trec_prefix="d2q-dual-task_real-dev_query_",
                        test_set="marcodev",
                        split_idx=i,
                        d2q_eval=True,
                        d2q_qrels=args.d2q_task_marco_dev_qrels)
                    if i != args.testing_split_idx:
                        d2q_marcodev_ndcg += d2q_ndcg_10_dev_split_i
                    eval_dict_todump[
                        f'd2q_marcodev_split_{i}_ndcg_cut_10'] = d2q_ndcg_10_dev_split_i
                logger.info(
                    f"average marco dev d2q task {d2q_marcodev_ndcg / (args.dev_split_num - 1)}")
            else:
                d2q_marcodev_ndcg = evaluation(
                    real_dev_ANCE_ids,
                    dev_query_embedding2id,
                    real_dev_d2q_I,
                    real_dev_d2q_D,
                    trec_prefix="d2q-dual-task_real-dev_query_",
                    test_set="marcodev",
                    split_idx=-1,
                    d2q_eval=True,
                    d2q_qrels=args.d2q_task_marco_dev_qrels)
            eval_dict_todump['d2q_marcodev_ndcg'] = d2q_marcodev_ndcg

        return None  # dev_ndcg, num_queries_dev
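# Retrieval above is exact maximum-inner-product search over dense embeddings
# with faiss.IndexFlatIP. A self-contained sketch of that core pattern (random
# embeddings; the dimensions are illustrative):
import faiss
import numpy as np

dim, n_passages, n_queries, top_k = 8, 1000, 4, 5
passage_emb = np.random.rand(n_passages, dim).astype(np.float32)
query_emb = np.random.rand(n_queries, dim).astype(np.float32)

index = faiss.IndexFlatIP(dim)          # exact inner-product index
index.add(passage_emb)                  # rows are addressed by insertion order
D, I = index.search(query_emb, top_k)   # D: scores, I: [n_queries, top_k] row ids
print(I.shape, D.shape)                 # row ids still need the *_embedding2id mapping
index.reset()                           # free the stored vectors, as done above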
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs
    logger.info("total optimization steps = %d", t_total)

    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    optimizer_grouped_parameters.append({
        "params": [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate, eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate, eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".format(
                args.scheduler))

    # Check if saved optimizer and scheduler states exist
    if os.path.isfile(os.path.join(args.model_name_or_path, "optimizer.pt")) \
            and os.path.isfile(os.path.join(args.model_name_or_path, "scheduler.pt")) \
            and args.load_optimizer_scheduler:
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model, optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)
            logger.info("  Continuing training from checkpoint, will skip to saved global_step")
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d", global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except ValueError:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility

    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        count = 0
        avg_cls_norm = 0
        loss_avg = 0
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)
            # diagnostic pass only: no backward, just collect loss and CLS similarity
            with torch.no_grad():
                outputs = model(*batch)
                cls_norm = outputs[1]
                loss = outputs[0]
            count += 1
            avg_cls_norm += float(cls_norm.cpu().data)
            loss_avg += float(loss.cpu().data)
            print("SEED-Encoder norm: ", cls_norm)
            if count == 1024:
                # report running averages over the first 1024 batches, then stop
                print('avg_cls_sim: ', float(avg_cls_norm) / count)
                print('avg_loss: ', float(loss_avg) / count)
                return
            if args.max_steps > 0 and global_step > args.max_steps:
                train_iterator.close()
                break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()
    return global_step, tr_loss / global_step
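# The diagnostic train() above only averages the model's loss and CLS similarity
# over the first 1024 batches under torch.no_grad(). A minimal sketch of that
# kind of no-grad running average (the toy callable stands in for the model):
import torch


@torch.no_grad()
def average_model_outputs(model, batches, cap=1024):
    count, loss_sum, norm_sum = 0, 0.0, 0.0
    for batch in batches:
        loss, cls_norm = model(batch)  # assumes (loss, cls_norm) outputs as above
        loss_sum += float(loss)
        norm_sum += float(cls_norm)
        count += 1
        if count == cap:
            break
    return loss_sum / count, norm_sum / count


if __name__ == "__main__":
    toy = lambda batch: (batch.sum(), batch.norm())
    print(average_model_outputs(toy, [torch.ones(3) for _ in range(10)], cap=4))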
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data,
                     latest_step_num):
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "test-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
    dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
    with dev_query_cache_trivia as emb:
        # use a distinct embedding prefix so the trivia query embeddings do not
        # collide with the nq dev query files written just above
        dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_trivia_" + str(latest_step_num) + "_", emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb,
            is_query_inference=False, load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        dim = passage_embedding.shape[1]
        logger.info('passage embedding shape: %s', str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN hit rates on the NQ test queries
        _, dev_I = cpu_index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
        top_k_hits = validate(passage_text, test_answers, dev_I,
                              dev_query_embedding2id, passage_embedding2id)

        # measure ANN hit rates on the TriviaQA test queries
        _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100)
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I,
                                     dev_query_embedding2id_trivia,
                                     passage_embedding2id)

        logger.info("Start searching for query embedding with length %d",
                    len(query_embedding))
        _, I = cpu_index.search(query_embedding, top_k)  # I: [number of queries, topk]

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())
        logger.info("Effective qid length %d, search result length %d",
                    len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(
            args, passage_text, train_answers, query_embedding2id,
            passage_embedding2id, I, train_pos_id)
        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir,
                                              "ann_training_data_" + str(output_num))
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid,
                    ','.join(str(neg_pid)
                             for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir,
                                        "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump(
                {
                    'top20': top_k_hits[19],
                    'top100': top_k_hits[99],
                    'top20_trivia': top_k_hits_trivia[19],
                    'top100_trivia': top_k_hits_trivia[99],
                    'checkpoint': checkpoint_path
                }, f)
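# validate() above returns a cumulative hits@k array, so top_k_hits[19] is
# top-20 accuracy and top_k_hits[99] is top-100. A sketch of how such an array
# is typically built from per-query first-hit ranks (this mirrors the indexing
# convention used above, not the repo's validate() itself):
import numpy as np


def cumulative_topk_hits(first_hit_ranks, depth=100):
    """first_hit_ranks: 0-based rank of the first correct passage, or -1 if none."""
    hits = np.zeros(depth)
    for rank in first_hit_ranks:
        if 0 <= rank < depth:
            hits[rank] += 1
    return np.cumsum(hits) / len(first_hit_ranks)


scores = cumulative_topk_hits([0, 3, -1, 50])
print(scores[19], scores[99])  # top-20 and top-100 hit rates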
def generate_new_ann(args, output_num, checkpoint_path,
                     training_query_positive_id, dev_query_positive_id,
                     latest_step_num):
    config, tokenizer, model = load_model(args, checkpoint_path)

    # the BM25 scoring below needs the tokenized train queries even when
    # pre-existing embeddings are found, so build the cache up front
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)

    dataFound = False
    if args.end_output_num == 0 and is_first_worker():
        dataFound, query_embedding, query_embedding2id, dev_query_embedding, \
            dev_query_embedding2id, passage_embedding, passage_embedding2id = \
            load_init_embeddings(args)
    if args.local_rank != -1:
        dist.barrier()

    if not dataFound:
        logger.info("***** inference of dev query *****")
        name_ = None
        if args.dataset == "dl_test_2019":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_1":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-1")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_2":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-2")
            name_ = "test_query_"
        else:
            raise Exception(
                'Dataset should be one of {dl_test_2019, dl_test_2019_1, dl_test_2019_2}!')
        dev_query_cache = EmbeddingCache(dev_query_collection_path)
        with dev_query_cache as emb:
            dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
                args, model, GetProcessingFn(args, query=True),
                name_ + str(latest_step_num) + "_", emb, is_query_inference=True)

        logger.info("***** inference of passages *****")
        passage_collection_path = os.path.join(args.data_dir, "passages")
        passage_cache = EmbeddingCache(passage_collection_path)
        with passage_cache as emb:
            passage_embedding, passage_embedding2id = StreamInferenceDoc(
                args, model, GetProcessingFn(args, query=False),
                "passage_" + str(latest_step_num) + "_", emb,
                is_query_inference=False)
        logger.info("***** Done passage inference *****")

        if args.inference:
            return

        logger.info("***** inference of train query *****")
        with train_query_cache as emb:
            query_embedding, query_embedding2id = StreamInferenceDoc(
                args, model, GetProcessingFn(args, query=True),
                "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    else:
        logger.info("***** Found pre-existing embeddings, skipping inference. *****")

    if is_first_worker():
        dim = passage_embedding.shape[1]
        logger.info('passage embedding shape: %s', str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN ndcg on the dev queries
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(args, dev_query_embedding2id,
                                                 passage_embedding2id,
                                                 dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor
        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) \
            else (q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]
        logger.info("Chunked {} query from {}".format(len(query_embedding), num_queries))

        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)
        effective_q_id = set(query_embedding2id.flatten())

        _, dev_I_dist = cpu_index.search(dev_query_embedding, top_k)
        distrib, samplingDist = getSamplingDist(args, dev_I_dist,
                                                dev_query_embedding2id,
                                                dev_query_positive_id,
                                                passage_embedding2id)
        sampling_dist_data = {'distrib': distrib, 'samplingDist': samplingDist}
        dist_output_path = os.path.join(args.output_dir, "dist_" + str(output_num))
        with open(dist_output_path, 'wb') as f:
            pickle.dump(sampling_dist_data, f)

        query_negative_passage = GenerateNegativePassaageID(
            args, query_embedding, query_embedding2id, passage_embedding,
            passage_embedding2id, training_query_positive_id, I,
            effective_q_id, samplingDist, output_num)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir,
                                              "ann_training_data_" + str(output_num))
        train_query_cache.open()
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                pos_score = get_BM25_score(query_id, pos_pid, train_query_cache, tokenizer)
                neg_scores = {}
                for neg_pid in query_negative_passage[query_id]:
                    neg_scores[neg_pid] = get_BM25_score(query_id, neg_pid,
                                                         train_query_cache, tokenizer)
                f.write("{}\t{}\t{}\n".format(
                    query_id,
                    str(pos_pid) + ":" + str(round(pos_score, 3)),
                    ','.join(str(neg_pid) + ":" + str(round(neg_scores[neg_pid], 3))
                             for neg_pid in query_negative_passage[query_id])))
        train_query_cache.close()

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
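# This variant writes triplets with BM25 scores attached as "pid:score". A
# minimal sketch of a reader for that layout (the sample line is illustrative):
def parse_scored_triplet(line):
    """'qid<TAB>pos:score<TAB>neg:score,...' -> (qid, (pid, score), [(pid, score), ...])"""
    qid, pos, negs = line.rstrip("\n").split("\t")
    pos_pid, pos_score = pos.split(":")
    neg_pairs = [n.split(":") for n in negs.split(",")]
    return (int(qid), (int(pos_pid), float(pos_score)),
            [(int(pid), float(score)) for pid, score in neg_pairs])


print(parse_scored_triplet("7\t42:12.345\t9:8.1,13:7.02"))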