Example #1
def ann_data_gen(args):
    last_checkpoint = args.last_checkpoint_dir
    ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir)

    if is_first_worker():
        logger.info("Getting bm25_helper")
        global bm25_helper
        bm25_helper = BM25_helper(args)
        logger.info("Done loading bm25_helper")

    output_num = ann_no + 1

    logger.info("starting output number %d", output_num)

    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)

    training_positive_id, dev_positive_id = load_positive_ids(args)

    while args.end_output_num == -1 or output_num <= args.end_output_num:

        next_checkpoint, latest_step_num = get_latest_checkpoint(args)

        if args.only_keep_latest_embedding_file:
            latest_step_num = 0

        if next_checkpoint == last_checkpoint:
            print("Sleeping for 1 hr")
            time.sleep(3600)
        else:
            logger.info("start generate ann data number %d", output_num)
            logger.info("next checkpoint at " + next_checkpoint)
            generate_new_ann(
                args,
                output_num,
                next_checkpoint,
                training_positive_id,
                dev_positive_id,
                latest_step_num)
            if args.inference:
                break
            logger.info("finished generating ann data number %d", output_num)
            output_num += 1
            last_checkpoint = next_checkpoint
        if args.local_rank != -1:
            dist.barrier()
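# ann_data_gen polls get_latest_ann_data(args.output_dir) to find the newest
# generation.  A minimal sketch of what such a helper might look like, assuming
# the ann_training_data_<n> / ann_ndcg_<n> naming used by generate_new_ann later
# in these examples (the (-1, None, None) return convention when nothing exists
# is also an assumption):
import json
import os


def get_latest_ann_data_sketch(ann_data_dir):
    """Return (ann_no, ann_path, ndcg_json) for the newest generation, or (-1, None, None)."""
    if not os.path.exists(ann_data_dir):
        return -1, None, None
    nums = [
        int(name.split("_")[-1])
        for name in os.listdir(ann_data_dir)
        if name.startswith("ann_ndcg_")
    ]
    if not nums:
        return -1, None, None
    ann_no = max(nums)
    ann_path = os.path.join(ann_data_dir, "ann_training_data_" + str(ann_no))
    with open(os.path.join(ann_data_dir, "ann_ndcg_" + str(ann_no)), "r") as f:
        ndcg_json = json.load(f)
    return ann_no, ann_path, ndcg_json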
Example #2
def ann_data_gen(args):
    last_checkpoint = args.last_checkpoint_dir
    ann_no, ann_path, ndcg_json = get_latest_ann_data(args.output_dir)
    output_num = ann_no + 1

    logger.info("starting output number %d", output_num)
    preloaded_data = None
    
    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)
        preloaded_data = load_data(args)

    while args.end_output_num == -1 or output_num <= args.end_output_num:
        next_checkpoint, latest_step_num = get_latest_checkpoint(args)
        logger.info(f"get next_checkpoint {next_checkpoint} latest_step_num {latest_step_num} ")
        if args.only_keep_latest_embedding_file:
            latest_step_num = 0

        if next_checkpoint == last_checkpoint:
            time.sleep(60)
        else:
            logger.info("start generate ann data number %d", output_num)
            logger.info("next checkpoint at " + next_checkpoint)
            generate_new_ann(args, output_num, next_checkpoint, preloaded_data, latest_step_num)
            logger.warning("process rank: %s, finished generating ann data number %d", args.local_rank, output_num)
            # logger.info("finished generating ann data number %d", output_num)
            output_num += 1
            last_checkpoint = next_checkpoint
        if args.local_rank != -1:
            dist.barrier()
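# A rough sketch of how ann_data_gen might be launched per process.  The device
# setup mirrors common torch.distributed boilerplate; args.no_cuda and the exact
# argparse wiring are assumptions, not part of the code above:
import torch
import torch.distributed as dist


def ann_data_gen_main_sketch(args):
    if args.local_rank == -1 or args.no_cuda:
        args.device = torch.device(
            "cuda" if torch.cuda.is_available() and not args.no_cuda else "cpu")
    else:
        # one process per GPU; rank/world size come from the launcher environment
        torch.cuda.set_device(args.local_rank)
        args.device = torch.device("cuda", args.local_rank)
        dist.init_process_group(backend="nccl")
    ann_data_gen(args)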
Example #3
def save_checkpoint(args, model, tokenizer):
    # Saving best-practices: if you use defaults names for the model, you can reload it using from_pretrained()
    if args.do_train and is_first_worker():
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        logger.info("Saving model checkpoint to %s", args.output_dir)
        # Save a trained model, configuration and tokenizer using `save_pretrained()`.
        # They can then be reloaded using `from_pretrained()`

        if 'fairseq' not in args.train_model_type:
            model_to_save = (
                model.module if hasattr(model, "module") else model
            )  # Take care of distributed/parallel training
            model_to_save.save_pretrained(args.output_dir)
            tokenizer.save_pretrained(args.output_dir)
        else:
            torch.save(model.state_dict(),
                       os.path.join(args.output_dir, 'model.pt'))

        # Good practice: save your training arguments together with the trained model
        torch.save(args, os.path.join(args.output_dir, "training_args.bin"))

    dist.barrier()
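# Reloading what save_checkpoint wrote is symmetric to its two branches.  A
# minimal sketch, assuming a HuggingFace-style class (with from_pretrained) for
# the non-fairseq branch and a plain state_dict for the fairseq branch;
# model_class / tokenizer_class are hypothetical parameters:
import os
import torch


def reload_checkpoint_sketch(args, model_class, tokenizer_class):
    if 'fairseq' not in args.train_model_type:
        model = model_class.from_pretrained(args.output_dir)
        tokenizer = tokenizer_class.from_pretrained(args.output_dir)
    else:
        model = model_class(args)  # assumed constructor signature
        state_dict = torch.load(os.path.join(args.output_dir, 'model.pt'), map_location='cpu')
        model.load_state_dict(state_dict)
        tokenizer = None
    train_args = torch.load(os.path.join(args.output_dir, "training_args.bin"))
    return model, tokenizer, train_args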
Example #4
def evaluation(args, model, tokenizer):
    # Evaluation
    results = {}
    if args.do_eval:
        model_dir = args.model_name_or_path if args.model_name_or_path else args.output_dir

        checkpoints = [model_dir]

        for checkpoint in checkpoints:
            global_step = checkpoint.split(
                "-")[-1] if len(checkpoints) > 1 else ""
            prefix = checkpoint.split(
                "/")[-1] if checkpoint.find("checkpoint") != -1 else ""

            model.eval()
            recall = passage_dist_eval_last(args, model, tokenizer)
            print('recall@1000: ', recall)
            reranking_mrr, full_ranking_mrr = passage_dist_eval(
                args, model, tokenizer)

            if is_first_worker():
                print("Reranking/Full ranking mrr: {0}/{1}".format(
                    str(reranking_mrr), str(full_ranking_mrr)))

            dist.barrier()
    return results
Example #5
def ann_data_gen(args):
    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)

    training_positive_id, dev_positive_id = load_positive_ids(args)

    finished_checkpoint_list = []

    while True:
        all_checkpoint_lists = get_all_checkpoint(
            args)  # include init_checkpoint
        logger.info("get all the checkpoints list:\n %s", all_checkpoint_lists)
        for checkpoint_path in all_checkpoint_lists:
            if checkpoint_path not in finished_checkpoint_list:
                logger.info(
                    f"inference and eval for checkpoint at {checkpoint_path}")
                generate_new_ann(args, checkpoint_path)
                logger.info(
                    f"finished generating ann data number at {checkpoint_path}"
                )
                finished_checkpoint_list.append(checkpoint_path)
            if args.local_rank != -1:
                dist.barrier()

        if args.inference_one_specified_ckpt:
            break

        time.sleep(600)
Example #6
def compute_mrr(D, I, qids, ref_dict):
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    if is_first_worker():
        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # we pad with negative pids and distance -128 - if they make it to the top we have a problem
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :10]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                candidate_dict[qid] = [0] * 1000
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    j += 1
                    seen_pids.add(pid)

        allowed, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)

        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)
    return mrr
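# compute_metrics above is assumed to return a dict containing "MRR @10".  A
# small reference computation over the same structures, assuming ref_dict maps
# qid -> collection of relevant pids and candidate_dict maps qid -> ranked pids
# (a sketch, not the repository's exact implementation):
def mrr_at_10_sketch(ref_dict, candidate_dict):
    total, num_queries = 0.0, 0
    for qid, ranked_pids in candidate_dict.items():
        if qid not in ref_dict:
            continue
        num_queries += 1
        for rank, pid in enumerate(ranked_pids[:10], start=1):
            if pid in ref_dict[qid]:
                total += 1.0 / rank  # reciprocal rank of the first relevant hit
                break
    return total / max(num_queries, 1)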
Example #7
def compute_mrr_last(D, I, qids, ref_dict, dev_query_positive_id):
    knn_pkl = {"D": D, "I": I}
    all_knn_list = all_gather(knn_pkl)
    mrr = 0.0
    final_recall = 0.0
    if is_first_worker():
        prediction = {}
        D_merged = concat_key(all_knn_list, "D", axis=1)
        I_merged = concat_key(all_knn_list, "I", axis=1)
        print(D_merged.shape, I_merged.shape)
        # we pad with negative pids and distance -128 - if they make it to the top we have a problem
        idx = np.argsort(D_merged, axis=1)[:, ::-1][:, :1000]
        sorted_I = np.take_along_axis(I_merged, idx, axis=1)
        candidate_dict = {}
        for i, qid in enumerate(qids):
            seen_pids = set()
            if qid not in candidate_dict:
                prediction[qid] = {}
                candidate_dict[qid] = [0] * 1000
            j = 0
            for pid in sorted_I[i]:
                if pid >= 0 and pid not in seen_pids:
                    candidate_dict[qid][j] = pid
                    prediction[qid][pid] = -(j + 1)  #-rank
                    j += 1
                    seen_pids.add(pid)

        allowed, message = quality_checks_qids(ref_dict, candidate_dict)
        if message != '':
            print(message)

        mrr_metrics = compute_metrics(ref_dict, candidate_dict)
        mrr = mrr_metrics["MRR @10"]
        print(mrr)

        evaluator = pytrec_eval.RelevanceEvaluator(
            convert_to_string_id(dev_query_positive_id), {'recall'})

        eval_query_cnt = 0
        recall = 0
        topN = 1000
        result = evaluator.evaluate(convert_to_string_id(prediction))
        for k in result.keys():
            eval_query_cnt += 1
            recall += result[k]["recall_" + str(topN)]

        final_recall = recall / eval_query_cnt
        print('final_recall: ', final_recall)

    return mrr, final_recall
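# pytrec_eval expects string query/document ids, so convert_to_string_id above
# is assumed to do only that key conversion on a {qid: {pid: value}} mapping.
# A minimal sketch:
def convert_to_string_id_sketch(result_dict):
    return {str(qid): {str(pid): val for pid, val in docs.items()}
            for qid, docs in result_dict.items()}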
Example #8
def ann_data_gen(args):
    if is_first_worker():
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)
        if not os.path.exists(args.cache_dir):
            os.makedirs(args.cache_dir)

    checkpoint_path = args.init_model_dir

    logger.info(f"inference and eval for checkpoint at {checkpoint_path}")
    generate_new_ann(args, checkpoint_path)
    logger.info(f"finished generating ann data number at {checkpoint_path}")
    if args.local_rank != -1:
        dist.barrier()
Example #9
def inference_or_load_embedding(args,logger,model,checkpoint_path,text_data_prefix, emb_prefix, is_query_inference=True,checkonly=False,load_emb=True):
    # logging.info(f"checkpoint_path {checkpoint_path}")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/','')
    emb_file_pattern = os.path.join(args.output_dir,f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_*.pb')
    emb_file_lists = glob.glob(emb_file_pattern)
    emb_file_lists = sorted(emb_file_lists, key=lambda name: int(name.split('_')[-1].replace('.pb',''))) # sort with split num
    logger.info(f"pattern {emb_file_pattern}\n file lists: {emb_file_lists}")
    embedding, embedding2id = None, None
    if len(emb_file_lists) > 0:
        if is_first_worker():
            logger.info(f"***** found existing embedding files {emb_file_pattern}, loading... *****")
            if checkonly:
                logger.info("check embedding files only, not loading")
                return embedding,embedding2id
            embedding = []
            embedding2id = []
            for emb_file in emb_file_lists:
                if load_emb:
                    with open(emb_file,'rb') as handle:
                        embedding.append(pickle.load(handle))
                embid_file = emb_file.replace('emb_p','embid_p')
                with open(embid_file,'rb') as handle:
                    embedding2id.append(pickle.load(handle))
            if (load_emb and not embedding) or (not embedding2id):
                logger.error("No data found for checkpoint: %s", emb_file_pattern)
            if load_emb:
                embedding = np.concatenate(embedding, axis=0)
            embedding2id = np.concatenate(embedding2id, axis=0)
        # return embedding,embedding2id
    # else:
    #    if args.local_rank != -1:
    #        dist.barrier() # if multi-processing
    else:
        logger.info(f"***** inference of {text_data_prefix} *****")
        query_collection_path = os.path.join(args.data_dir, text_data_prefix)
        query_cache = EmbeddingCache(query_collection_path)
        with query_cache as emb:
            embedding,embedding2id = StreamInferenceDoc(args, model, 
                GetProcessingFn(args, query=is_query_inference),
                emb_prefix + str(checkpoint_step) + "_", emb,
                is_query_inference=is_query_inference)
    return embedding,embedding2id
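# The loading branch above expects each inference shard to have written two
# pickles: "<emb_prefix><step>__emb_p__data_obj_<split>.pb" (embeddings) and the
# matching "__embid_p__" file (ids).  A sketch of the assumed writer side
# (function name and protocol choice are illustrative):
import os
import pickle

import numpy as np


def dump_embedding_shard_sketch(output_dir, emb_prefix, checkpoint_step,
                                split_no, embedding, embedding2id):
    base = f"{emb_prefix}{checkpoint_step}__{{}}__data_obj_{split_no}.pb"
    with open(os.path.join(output_dir, base.format("emb_p")), "wb") as handle:
        pickle.dump(np.asarray(embedding), handle, protocol=4)
    with open(os.path.join(output_dir, base.format("embid_p")), "wb") as handle:
        pickle.dump(np.asarray(embedding2id), handle, protocol=4)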
Example #10
def load_model(args):
    # Prepare GLUE task
    args.task_name = args.task_name.lower()
    args.output_mode = "classification"
    label_list = ["0", "1"]
    num_labels = len(label_list)

    # store args
    if args.local_rank != -1:
        args.world_size = torch.distributed.get_world_size()
        args.rank = dist.get_rank()

    # Load pretrained model and tokenizer
    if args.local_rank not in [-1, 0]:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    args.model_type = args.model_type.lower()
    configObj = MSMarcoConfigDict[args.model_type]
    # tokenizer = configObj.tokenizer_class.from_pretrained(
    #     "bert-base-uncased",
    #     do_lower_case=True,
    #     cache_dir=args.cache_dir if args.cache_dir else None,
    # )

    if is_first_worker():
        # Create output directory if needed
        if not os.path.exists(args.output_dir):
            os.makedirs(args.output_dir)

        #if not os.path.exists(args.blob_output_dir):
        #    os.makedirs(args.blob_output_dir)

    model = configObj.model_class(args)

    if args.local_rank == 0:
        torch.distributed.barrier(
        )  # Make sure only the first process in distributed training will download model & vocab

    model.to(args.device)
    return model
Example #11
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        training_query_positive_id,
        dev_query_positive_id,
        latest_step_num):
    if args.gpu_index:
        clean_faiss_gpu()
    if not args.not_load_model_for_inference:
        config, tokenizer, model = load_model(args, checkpoint_path)
    
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/','')

    def evaluation(dev_query_embedding2id,passage_embedding2id,dev_I,dev_D,trec_prefix="real-dev_query_",test_set="trec2019",split_idx=-1,d2q_eval=False,d2q_qrels=None):
        if d2q_eval:
            qrels=d2q_qrels
        else:
            if args.data_type ==0 :
                if not d2q_eval:
                    if  test_set== "marcodev":
                        qrels="../data/raw_data/msmarco-docdev-qrels.tsv"
                    elif test_set== "trec2019":
                        qrels="../data/raw_data/2019qrels-docs.txt"
            elif args.data_type ==1:
                if test_set == "marcodev":
                    qrels="../data/raw_data/qrels.dev.small.tsv"
            else:
                logging.error("wrong data type")
                exit()
        trec_path=os.path.join(args.output_dir, trec_prefix + str(checkpoint_step)+".trec")
        save_trec_file(
            dev_query_embedding2id,passage_embedding2id,dev_I,dev_D,
            trec_save_path= trec_path,
            topN=200)
        convert_trec_to_MARCO_id(
            data_type=args.data_type,test_set=test_set,
            processed_data_dir=args.data_dir,
            trec_path=trec_path,d2q_reversed_trec_file=d2q_eval)

        trec_path=trec_path.replace(".trec",".formatted.trec")
        met = Metric()
        if split_idx >= 0:
            split_file_path=qrels+f"{args.dev_split_num}_fold.split_dict"
            with open(split_file_path,'rb') as f:
                split=pickle.load(f)
        else:
            split=None
        
        ndcg10 = met.get_metric(qrels, trec_path, 'ndcg_cut_10',split,split_idx)
        mrr10 = met.get_mrr(qrels, trec_path, 'mrr_cut_10',split,split_idx)
        mrr100 = met.get_mrr(qrels, trec_path, 'mrr_cut_100',split,split_idx)

        logging.info(f" evaluation for {test_set}, trec_file {trec_path}, split_idx {split_idx} \
            ndcg_cut_10 : {ndcg10}, \
            mrr_cut_10 : {mrr10}, \
            mrr_cut_100 : {mrr100}"
        )

        return ndcg10


    # Inference 
    if args.data_type==0:
        # TREC DL 2019 evalset
        trec2019_query_embedding, trec2019_query_embedding2id = inference_or_load_embedding(args=args,logger=logger,model=model, checkpoint_path=checkpoint_path,text_data_prefix="dev-query", emb_prefix="dev_query_", is_query_inference=True)# it's trec-dl testset actually
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(args=args,logger=logger,model=model, checkpoint_path=checkpoint_path, text_data_prefix="real-dev-query", emb_prefix="real-dev_query_", is_query_inference=True)
    query_embedding, query_embedding2id = inference_or_load_embedding(args=args,logger=logger,model=model, checkpoint_path=checkpoint_path,text_data_prefix="train-query", emb_prefix="query_", is_query_inference=True)
    if not args.split_ann_search:
        # merge all passage
        passage_embedding, passage_embedding2id = inference_or_load_embedding(args=args,logger=logger,model=model, checkpoint_path=checkpoint_path,text_data_prefix="passages", emb_prefix="passage_", is_query_inference=False)
    else:
        # keep id only
        _, passage_embedding2id = inference_or_load_embedding(args=args,logger=logger,model=model, checkpoint_path=checkpoint_path,text_data_prefix="passages", emb_prefix="passage_", is_query_inference=False,load_emb=False)
        
    # FirstP shape,
    # passage_embedding: [[vec_0], [vec_1], [vec_2], [vec_3] ...], 
    # passage_embedding2id: [id0, id1, id2, id3, ...]

    # MaxP shape,
    # passage_embedding: [[vec_0_0], [vec_0_1],[vec_0_2],[vec_0_3],[vec_1_0],[vec_1_1] ...], 
    # passage_embedding2id: [id0, id0, id0, id0, id1, id1 ...]
    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if args.inference:
        return

    if is_first_worker():
        # Construct new training subset
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor
        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (
            effective_idx == (
                chunk_factor -
                1)) else (
            q_start_idx +
            queries_per_chunk)
        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding[q_start_idx:q_end_idx]),
                num_queries))

        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            print('passage embedding shape: ' + str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            cpu_index = faiss.IndexFlatIP(dim)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            # to measure ANN MRR
            logger.info("search dev query")
            dev_D, dev_I = index.search(dev_query_embedding, 100) # I: [number of queries, topk]
            logger.info("finish")
            logger.info("search train query")
            D, I = index.search(query_embedding[q_start_idx:q_end_idx], top_k) # I: [number of queries, topk]
            logger.info("finish")
            index.reset()
        else:
            if args.data_type==0:
                trec2019_D, trec2019_I, _, _ = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    checkpoint_step=checkpoint_step,
                    top_k_dev = 200,
                    top_k = args.topk_training,
                    dev_query_emb=trec2019_query_embedding,
                    train_query_emb=None,
                    emb_prefix="passage_",two_query_set=False,
                )
            dev_D, dev_I, D, I = document_split_faiss_index(
                logger=logger,
                args=args,
                checkpoint_step=checkpoint_step,
                top_k_dev = 200,
                top_k = args.topk_training,
                dev_query_emb=dev_query_embedding,
                train_query_emb=query_embedding[q_start_idx:q_end_idx],
                emb_prefix="passage_")
            logger.info("***** seperately process indexing *****")
        
        
        logger.info("***** Done ANN Index *****")

        # dev_ndcg, num_queries_dev = EvalDevQuery(
        #     args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)
        logger.info("***** Begin evaluation *****")
        eval_dict_todump={'checkpoint': checkpoint_path}

        if args.data_type==0:
            trec2019_ndcg = evaluation(trec2019_query_embedding2id,passage_embedding2id,trec2019_I,trec2019_D,trec_prefix="dev_query_",test_set="trec2019")
        if args.dev_split_num > 0:
            marcodev_ndcg = 0.0
            for i in range(args.dev_split_num):
                ndcg_10_dev_split_i = evaluation(dev_query_embedding2id,passage_embedding2id,dev_I,dev_D,trec_prefix="real-dev_query_",test_set="marcodev",split_idx=i)
                if i != args.testing_split_idx:
                    marcodev_ndcg += ndcg_10_dev_split_i
                
                eval_dict_todump[f'marcodev_split_{i}_ndcg_cut_10'] = ndcg_10_dev_split_i

            logger.info(f"average marco dev { marcodev_ndcg /(args.dev_split_num -1)}")
        else:
            marcodev_ndcg = evaluation(dev_query_embedding2id,passage_embedding2id,dev_I,dev_D,trec_prefix="real-dev_query_",test_set="marcodev",split_idx=-1)
        
        eval_dict_todump['marcodev_ndcg']=marcodev_ndcg

        query_range_number = I.shape[0]
        if args.save_training_query_trec:
            logger.info("***** Save the ANN searching for negative passages in trec file format *****")    
            trec_output_path=os.path.join(args.output_dir, "ann_training_query_retrieval_" + str(output_num)+".trec")
            save_trec_file(query_embedding2id[q_start_idx:q_end_idx],passage_embedding2id,I,D,trec_output_path,topN=args.topk_training)

        effective_q_id = set(query_embedding2id[q_start_idx:q_end_idx].flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding2id[q_start_idx:q_end_idx],
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id)
        logger.info("***** Done ANN searching for negative passages *****")

        if args.d2q_task_evaluation and args.d2q_task_marco_dev_qrels is not None:
            with open(os.path.join(args.data_dir,'pid2offset.pickle'),'rb') as f:
                pid2offset = pickle.load(f)
            real_dev_ANCE_ids=[]
            with open(args.d2q_task_marco_dev_qrels+f"{args.dev_split_num}_fold.split_dict","rb") as f:
                dev_d2q_split_dict=pickle.load(f)
            for i in dev_d2q_split_dict:
                for stringdocid in dev_d2q_split_dict[i]:
                    if args.data_type==0:
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid[1:])])
                    else:
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid)])
            real_dev_ANCE_ids = np.array(real_dev_ANCE_ids).flatten()
            real_dev_possitive_training_passage_id_embidx=[]
            for dev_pos_pid in real_dev_ANCE_ids:
                embidx=np.asarray(np.where(passage_embedding2id==dev_pos_pid)).flatten()
                real_dev_possitive_training_passage_id_embidx.append(embidx)
                # possitive_training_passage_id_to_subset_embidx[int(dev_pos_pid)] = np.asarray(range(possitive_training_passage_id_emb_counts,possitive_training_passage_id_emb_counts+embidx.shape[0]))
                # possitive_training_passage_id_emb_counts += embidx.shape[0]
            real_dev_possitive_training_passage_id_embidx=np.concatenate(real_dev_possitive_training_passage_id_embidx,axis=0)
            del pid2offset
            if not args.split_ann_search:
                real_dev_positive_p_embs = passage_embedding[real_dev_possitive_training_passage_id_embidx]
            else:
                real_dev_positive_p_embs = loading_possitive_document_embedding(logger,args.output_dir,checkpoint_step,real_dev_possitive_training_passage_id_embidx,emb_prefix="passage_",)
            logger.info("***** d2q task evaluation *****")
            cpu_index = faiss.IndexFlatIP(dev_query_embedding.shape[1])
            index = cpu_index
            # index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(dev_query_embedding)
            real_dev_d2q_D, real_dev_d2q_I = index.search(real_dev_positive_p_embs, 200) 
            if args.dev_split_num > 0:
                d2q_marcodev_ndcg = 0.0
                for i in range(args.dev_split_num):
                    d2q_ndcg_10_dev_split_i = evaluation( 
                        real_dev_ANCE_ids,dev_query_embedding2id ,real_dev_d2q_I,real_dev_d2q_D,
                        trec_prefix="d2q-dual-task_real-dev_query_",test_set="marcodev",split_idx=i,d2q_eval=True,d2q_qrels=args.d2q_task_marco_dev_qrels)
                    if i != args.testing_split_idx:
                        d2q_marcodev_ndcg += d2q_ndcg_10_dev_split_i
                    eval_dict_todump[f'd2q_marcodev_split_{i}_ndcg_cut_10'] = d2q_ndcg_10_dev_split_i
                logger.info(f"average marco dev d2q task { d2q_marcodev_ndcg /(args.dev_split_num -1)}")
            else:
                d2q_marcodev_ndcg = evaluation(real_dev_ANCE_ids,dev_query_embedding2id ,real_dev_d2q_I,real_dev_d2q_D,
                    trec_prefix="d2q-dual-task_real-dev_query_",test_set="marcodev",split_idx=-1,d2q_eval=True,d2q_qrels=args.d2q_task_marco_dev_qrels)
            
            eval_dict_todump['d2q_marcodev_ndcg'] = d2q_marcodev_ndcg

        if args.dual_training:
            # do this before completely truncating the query embedding 
            logger.info("***** Do ANN Index for dual d2q task *****")
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            cpu_index = faiss.IndexFlatIP(query_embedding.shape[1])
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(query_embedding)
            logger.info("***** Done building ANN Index for dual d2q task *****")
            

            logger.info("***** use ANCE id to construct positive passage embedding index *****")
            training_query_positive_id_inversed = {} # {v:k for k,v in training_query_positive_id.items()} # doc_id : query_id
            for k in training_query_positive_id:
                pos_pid=training_query_positive_id[k]
                if pos_pid not in training_query_positive_id_inversed:
                    training_query_positive_id_inversed[pos_pid]=[k]
                else:
                    training_query_positive_id_inversed[pos_pid].append(k)
            
            possitive_training_passage_id = [ training_query_positive_id[t] for t in query_embedding2id[q_start_idx:q_end_idx]] # 
            effective_p_id = set(possitive_training_passage_id)

            if "BM25_retrieval" == args.query_likelihood_strategy:
                passage_negative_queries = {}
                logger.info("***** loading negative queries from BM25 search result *****")
                with open(args.bm25_top_d2q_path,"r") as f:
                    for line in f:
                        pid,qid,rank = line.strip().split("\t")
                        pid = int(pid)
                        qid = int(qid)
                        if (pid in effective_p_id) and (qid not in training_query_positive_id_inversed[pid]):
                            if pid not in passage_negative_queries:
                                passage_negative_queries[pid]=[qid]
                            elif len(passage_negative_queries[pid]) < args.topk_training_d2q:
                                passage_negative_queries[pid].append(qid)
                                
                logger.info(f"***** shuffle and pick {args.negative_sample} negative queries *****")
                for pid in passage_negative_queries:
                    random.shuffle(passage_negative_queries[pid])
                    passage_negative_queries[pid]=passage_negative_queries[pid][:args.negative_sample]

            else:
                # compatible with MaxP
                possitive_training_passage_id_embidx=[]
                possitive_training_passage_id_to_subset_embidx={} # pid to indexs in pos_pas_embs 
                possitive_training_passage_id_emb_counts=0
                for pos_pid in possitive_training_passage_id:
                    embidx=np.asarray(np.where(passage_embedding2id==pos_pid)).flatten()
                    possitive_training_passage_id_embidx.append(embidx)
                    possitive_training_passage_id_to_subset_embidx[int(pos_pid)] = np.asarray(range(possitive_training_passage_id_emb_counts,possitive_training_passage_id_emb_counts+embidx.shape[0]))
                    possitive_training_passage_id_emb_counts += embidx.shape[0]
                possitive_training_passage_id_embidx=np.concatenate(possitive_training_passage_id_embidx,axis=0)

                if not args.split_ann_search:
                    D, I = index.search(passage_embedding[possitive_training_passage_id_embidx], args.topk_training_d2q) 
                else:
                    if args.query_likelihood_strategy == "positive_doc":
                        positive_p_embs = loading_possitive_document_embedding(logger,args.output_dir,checkpoint_step,possitive_training_passage_id_embidx,emb_prefix="passage_",)
                    assert positive_p_embs.shape[0] == len(possitive_training_passage_id)
                    D, I = index.search(positive_p_embs, args.topk_training_d2q) 
                    del positive_p_embs

                index.reset()

                logger.info("***** Finish ANN searching for dual d2q task, construct  *****")

                passage_negative_queries = GenerateNegativeQueryID(
                    args,
                    possitive_training_passage_id,
                    query_embedding2id,
                    training_query_positive_id_inversed,
                    I,
                    effective_p_id,
                    pid2pos_pas_embs_idxs = possitive_training_passage_id_to_subset_embidx,
                    Scores_nearest_neighbor=D if "multi_chunk" in args.model_type else None)

            logger.info("***** Done ANN searching for negative queries *****")

        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info("***** Construct ANN Triplet *****")
        prefix =  "ann_grouped_training_data_" if args.grouping_ann_data > 0  else "ann_training_data_"
        train_data_output_path = os.path.join(
            args.output_dir, prefix + str(output_num))
        
    
        if args.grouping_ann_data > 0 :
            with open(train_data_output_path, 'w') as f:
                query_range = list(range(query_range_number))
                random.shuffle(query_range)
                
                counting=0
                pos_q_group={}
                pos_d_group={}
                neg_D_group={} # {0:[], 1:[], 2:[]...}
                if args.dual_training:
                    neg_Q_group={}
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    if query_id not in effective_q_id or query_id not in training_query_positive_id:
                        continue
                    pos_pid = training_query_positive_id[query_id]
                    
                    pos_q_group[counting]=int(query_id)
                    pos_d_group[counting]=int(pos_pid)

                    neg_D_group[counting]=[int(neg_pid) for neg_pid in query_negative_passage[query_id]]
                    if args.dual_training:
                        neg_Q_group[counting]=[int(neg_qid) for neg_qid in passage_negative_queries[pos_pid]]

                    counting +=1

                    if counting >= args.grouping_ann_data:
                        jsonline_dict={}
                        jsonline_dict["pos_q_group"]=pos_q_group
                        jsonline_dict["pos_d_group"]=pos_d_group
                        jsonline_dict["neg_D_group"]=neg_D_group
                        
                        if args.dual_training:
                            jsonline_dict["neg_Q_group"]=neg_Q_group

                        f.write(f"{json.dumps(jsonline_dict)}\n")

                        counting=0
                        pos_q_group={}
                        pos_d_group={}
                        neg_D_group={} # {0:[], 1:[], 2:[]...}
                        if args.dual_training:
                            neg_Q_group={}

        else:
            with open(train_data_output_path, 'w') as f:
                query_range = list(range(query_range_number))
                random.shuffle(query_range)
                # old version implementation 
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    if query_id not in effective_q_id or query_id not in training_query_positive_id:
                        continue
                    pos_pid = training_query_positive_id[query_id]
                    if not args.dual_training:
                        f.write(
                            "{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(
                                    str(neg_pid) for neg_pid in query_negative_passage[query_id])))
                    else:
                        if pos_pid not in effective_p_id or pos_pid not in training_query_positive_id_inversed:
                            continue
                        f.write(
                            "{}\t{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(
                                    str(neg_pid) for neg_pid in query_negative_passage[query_id]),
                                ','.join(
                                    str(neg_qid) for neg_qid in passage_negative_queries[pos_pid])
                            )
                        )
                    
        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        
        if args.data_type==0:
            eval_dict_todump['trec2019_ndcg']=trec2019_ndcg
        with open(ndcg_output_path, 'w') as f:
            json.dump(eval_dict_todump, f)

        return None #dev_ndcg, num_queries_dev
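# generate_new_ann writes either grouped json lines or tab-separated triplet
# lines; a sketch of the matching reader (column layout taken from the f.write
# calls above, everything else is an assumption):
import json


def parse_ann_training_line_sketch(line, grouped=False, dual_training=False):
    if grouped:
        # {"pos_q_group": ..., "pos_d_group": ..., "neg_D_group": ..., optionally "neg_Q_group": ...}
        return json.loads(line)
    cols = line.rstrip("\n").split("\t")
    record = {
        "query_id": int(cols[0]),
        "pos_pid": int(cols[1]),
        "neg_pids": [int(pid) for pid in cols[2].split(",")],
    }
    if dual_training:
        record["neg_qids"] = [int(qid) for qid in cols[3].split(",")]
    return record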
Example #12
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        latest_step_num,
        training_query_positive_id=None,
        dev_query_positive_id=None,
        ):
    config, tokenizer, model = load_model(args, checkpoint_path)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=True), "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    if args.inference:
        return

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False)
    logger.info("***** Done passage inference *****")

    

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(
            args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor
        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (
            effective_idx == (
                chunk_factor -
                1)) else (
            q_start_idx +
            queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding2id,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, pos_pid, ','.join(
                            str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
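# Both generate_new_ann variants use exact inner-product search with Faiss; the
# "I: [number of queries, topk]" comments refer to the shapes returned by
# index.search.  A tiny self-contained illustration on random float32 data:
import faiss
import numpy as np

dim, n_passages, n_queries, top_k = 8, 1000, 4, 5
passages = np.random.rand(n_passages, dim).astype(np.float32)
queries = np.random.rand(n_queries, dim).astype(np.float32)

index = faiss.IndexFlatIP(dim)       # exact inner-product (flat) index
index.add(passages)
D, I = index.search(queries, top_k)  # D: scores, I: passage row ids, both [n_queries, top_k]
assert D.shape == I.shape == (n_queries, top_k)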
Example #13
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint
        # from the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0

    save_no = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0 and global_step % args.save_steps < args.save_steps / 20:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                dev_ndcg = ndcg_json['ndcg']
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info("Total ann queries: %d", len(ann_training_data))
                if args.triplet:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                else:
                    train_dataset = StreamingDataset(
                        ann_training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data))

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    tb_writer.add_scalar("dev_ndcg", dev_ndcg,
                                         ann_checkpoint_no)
                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        batch = tuple(t.to(args.device) for t in batch)
        step += 1

        if args.triplet:
            inputs = {
                "query_ids": batch[0].long(),
                "attention_mask_q": batch[1].long(),
                "input_ids_a": batch[3].long(),
                "attention_mask_a": batch[4].long(),
                "input_ids_b": batch[6].long(),
                "attention_mask_b": batch[7].long()
            }
        else:
            inputs = {
                "input_ids_a": batch[0].long(),
                "attention_mask_a": batch[1].long(),
                "input_ids_b": batch[3].long(),
                "attention_mask_b": batch[4].long(),
                "labels": batch[6]
            }

        # sync gradients only at gradient accumulation step
        if step % args.gradient_accumulation_steps == 0:
            outputs = model(**inputs)
        else:
            with model.no_sync():
                outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        if args.n_gpu > 1:
            loss = loss.mean(
            )  # mean() to average on multi-gpu parallel training
        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if step % args.gradient_accumulation_steps == 0:
                loss.backward()
            else:
                with model.no_sync():
                    loss.backward()

        tr_loss += loss.item()
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if is_first_worker(
            ) and args.save_steps > 0 and global_step % args.save_steps == 0:
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

                save_no += 1

                if save_no > 1:
                    ann_no, ann_path, ndcg_json = get_latest_ann_data(
                        args.ann_dir)
                    while (ann_no == last_ann_no):
                        print("Waiting for new ann_data. Sleeping for 1hr!!")
                        time.sleep(3600)
                        ann_no, ann_path, ndcg_json = get_latest_ann_data(
                            args.ann_dir)

            dist.barrier()

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
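# The layer-wise optimizer grouping above relies on getattr_recursive to resolve
# dotted attribute paths such as "roberta.embeddings".  A minimal sketch of such
# a helper, returning None when any segment is missing (matching the None check
# in the loop above):
def getattr_recursive_sketch(obj, name):
    for attr in name.split("."):
        if not hasattr(obj, attr):
            return None
        obj = getattr(obj, attr)
    return obj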
Example #14
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(
        1, args.n_gpu)  #nll loss for query
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * (
        torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    optimizer = get_optimizer(
        args,
        model,
        weight_decay=args.weight_decay,
    )

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    tr_loss = 0.0
    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    dev_ndcg = 0
    step = 0
    iter_count = 0

    scheduler = get_linear_schedule_with_warmup(
        optimizer,
        num_warmup_steps=args.warmup_steps,
        num_training_steps=args.max_steps)

    global_step = 0
    if args.model_name_or_path != "bert-base-uncased":
        saved_state = load_states_from_checkpoint(args.model_name_or_path)
        global_step = _load_saved_state(model, optimizer, scheduler,
                                        saved_state)
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

        #nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
        #dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(args, model, passage_cache, "-trivia")
        #if is_first_worker():
        #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
        #    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
    print(args.num_epoch)
    #step = global_step
    print(step, args.max_steps, global_step)

    global_step = 0
    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:

            if args.num_epoch == 0:
                #print('yes')
                # check if new ann training data is available
                ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
                #print(ann_path)
                #print(ann_no)
                #print(ndcg_json)
                if ann_path is not None and ann_no != last_ann_no:
                    logger.info("Training on new add data at %s", ann_path)
                    time.sleep(180)
                    with open(ann_path, 'r') as f:
                        #print(ann_path)
                        ann_training_data = f.readlines()
                    logger.info("Training data line count: %d",
                                len(ann_training_data))
                    ann_training_data = [
                        l for l in ann_training_data
                        if len(l.split('\t')[2].split(',')) > 1
                    ]
                    logger.info("Filtered training data line count: %d",
                                len(ann_training_data))
                    #ann_checkpoint_path = ndcg_json['checkpoint']
                    #ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)
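                    # Truncate so the number of lines is divisible by world_size
                    # and every distributed worker sees the same number of examples.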

                    aligned_size = (len(ann_training_data) //
                                    args.world_size) * args.world_size
                    ann_training_data = ann_training_data[:aligned_size]

                    logger.info("Total ann queries: %d",
                                len(ann_training_data))
                    if args.triplet:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTripletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset, batch_size=args.train_batch_size)
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetTrainingDataProcessingFn(
                                args, query_cache, passage_cache))
                        train_dataloader = DataLoader(
                            train_dataset,
                            batch_size=args.train_batch_size * 2)
                    train_dataloader_iter = iter(train_dataloader)

                    # re-warmup
                    if not args.single_warmup:
                        scheduler = get_linear_schedule_with_warmup(
                            optimizer,
                            num_warmup_steps=args.warmup_steps,
                            num_training_steps=len(ann_training_data))

                    if args.local_rank != -1:
                        dist.barrier()

                    if is_first_worker():
                        # add ndcg at checkpoint step used instead of current step
                        #tb_writer.add_scalar("retrieval_accuracy/top20_nq", ndcg_json['top20'], ann_checkpoint_no)
                        #tb_writer.add_scalar("retrieval_accuracy/top100_nq", ndcg_json['top100'], ann_checkpoint_no)
                        #if 'top20_trivia' in ndcg_json:
                        #    tb_writer.add_scalar("retrieval_accuracy/top20_trivia", ndcg_json['top20_trivia'], ann_checkpoint_no)
                        #    tb_writer.add_scalar("retrieval_accuracy/top100_trivia", ndcg_json['top100_trivia'], ann_checkpoint_no)
                        if last_ann_no != -1:
                            tb_writer.add_scalar("epoch", last_ann_no,
                                                 global_step - 1)
                        tb_writer.add_scalar("epoch", ann_no, global_step)
                    last_ann_no = ann_no
            elif step == 0:

                train_data_path = os.path.join(args.data_dir, "train-data")
                with open(train_data_path, 'r') as f:
                    training_data = f.readlines()
                if args.triplet:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTripletTrainingDataProcessingFn(
                            args, query_cache, passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size)
                else:
                    train_dataset = StreamingDataset(
                        training_data,
                        GetTrainingDataProcessingFn(args, query_cache,
                                                    passage_cache))
                    train_dataloader = DataLoader(
                        train_dataset, batch_size=args.train_batch_size * 2)
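                # NOTE: materializing every batch here is only used to log the
                # batch count; for large files this extra pass can be costly.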
                all_batch = [b for b in train_dataloader]
                logger.info("Total batch count: %d", len(all_batch))
                train_dataloader_iter = iter(train_dataloader)

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            if args.num_epoch != 0:
                iter_count += 1
                if is_first_worker():
                    tb_writer.add_scalar("epoch", iter_count - 1,
                                         global_step - 1)
                    tb_writer.add_scalar("epoch", iter_count, global_step)
            #nq_dev_nll_loss, nq_correct_ratio = evaluate_dev(args, model, passage_cache)
            #dev_nll_loss_trivia, correct_ratio_trivia = evaluate_dev(args, model, passage_cache, "-trivia")
            #if is_first_worker():
            #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss", nq_dev_nll_loss, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/correct_ratio", nq_correct_ratio, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/dev_nll_loss_trivia", dev_nll_loss_trivia, global_step)
            #    tb_writer.add_scalar("dev_nll_loss/correct_ratio_trivia", correct_ratio_trivia, global_step)
            ann_no, ann_path, ndcg_json = get_latest_ann_data(args.ann_dir)
            if ann_path is not None:
                logger.info("Reloading ann training data from %s", ann_path)
                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                logger.info("Training data line count: %d",
                            len(ann_training_data))
                ann_training_data = [
                    l for l in ann_training_data
                    if len(l.split('\t')[2].split(',')) > 1
                ]
                logger.info("Filtered training data line count: %d",
                            len(ann_training_data))

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]
                train_dataset = StreamingDataset(
                    ann_training_data,
                    GetTrainingDataProcessingFn(args, query_cache,
                                                passage_cache))
                train_dataloader = DataLoader(
                    train_dataset, batch_size=args.train_batch_size * 2)

            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)
            if args.local_rank != -1:
                dist.barrier()

        if args.num_epoch != 0 and iter_count > args.num_epoch:
            break

        step += 1
        if args.triplet:
            loss = triplet_fwd_pass(args, model, batch)
        else:
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if args.local_rank != -1 and step % args.gradient_accumulation_steps != 0:
                # model.no_sync() only exists on DDP-wrapped models; it skips the
                # gradient all-reduce on accumulation micro-steps.
                with model.no_sync():
                    loss.backward()
            else:
                loss.backward()

        tr_loss += loss.item()
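        # Step the optimizer/scheduler only every gradient_accumulation_steps
        # micro-batches; global_step counts optimizer updates, not batches.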
        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                loss_scalar = tr_loss / args.logging_steps
                learning_rate_scalar = scheduler.get_lr()[0]
                logs["learning_rate"] = learning_rate_scalar
                logs["loss"] = loss_scalar
                tr_loss = 0

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if (is_first_worker() and args.save_steps > 0
                    and global_step % args.save_steps == 0):
                _save_checkpoint(args, model, optimizer, scheduler,
                                 global_step)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
Example #15
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data,
                     latest_step_num):
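    # Sketch of the flow: encode train/dev queries and all passages with the
    # current checkpoint, build an exact FAISS inner-product index over the
    # passages, and mine top-k retrieved passages as hard-negative candidates
    # for each training query.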
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir,
                                               "train-sec-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "query_sec_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-sec-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "dev_sec_query_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=False,
            load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        train_pos_id, test_pos_id = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        #num_q = passage_embedding.shape[0]

        faiss.omp_set_num_threads(16)
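        # IndexFlatIP does exact (brute-force) maximum inner-product search over
        # all passage embeddings, so no index training is needed.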
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)

        logger.info('Data indexing completed.')

        logger.info("Start searching for query embedding with length %d",
                    len(query_embedding))

        # Search in chunks of 5,000 queries to bound peak memory; the original
        # hard-coded 15 chunks, which assumes at most 75,000 training queries.
        chunk_size = 5000
        II = list()
        for i in range(0, len(query_embedding), chunk_size):
            _, idx = cpu_index.search(query_embedding[i:i + chunk_size],
                                      top_k)  # idx: [number of queries, topk]
            II.append(idx)
            logger.info("Split done %d", i // chunk_size)

        I = np.concatenate(II, axis=0)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        #logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = dict()
        for query_idx in range(I.shape[0]):
            query_id = query_embedding2id[query_idx]
            doc_ids = [passage_embedding2id[pidx] for pidx in I[query_idx]]

            # the positives for this query are fixed, so parse them once
            pos_id = [
                int(p_id) for p_id in train_pos_id[query_id].split(',')
            ]
            neg_docs = list()
            for doc_id in doc_ids:
                if doc_id in pos_id:
                    continue
                if doc_id in neg_docs:
                    continue
                neg_docs.append(doc_id)

            query_negative_passage[query_id] = neg_docs

        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid, ','.join(
                        str(neg_pid)
                        for neg_pid in query_negative_passage[query_id])))

        _, dev_I = cpu_index.search(dev_query_embedding,
                                    10)  #I: [number of queries, topk]

        top_k_hits = validate(test_pos_id, dev_I, dev_query_embedding2id,
                              passage_embedding2id)
Example #16
def train(args, model, tokenizer, query_cache, passage_cache):
    """ Train the model """
    logger.info("Training/evaluation parameters %s", args)
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)
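    # Build one optimizer parameter group per transformer layer (plus the extra
    # heads) for the layer-wise LAMB update; any remaining parameters fall into
    # a final catch-all group.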

    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    def optimizer_to(optim, device):
        for param in optim.state.values():
            # Not sure there are any global tensors in the state dict
            if isinstance(param, torch.Tensor):
                param.data = param.data.to(device)
                if param._grad is not None:
                    param._grad.data = param._grad.data.to(device)
            elif isinstance(param, dict):
                for subparam in param.values():
                    if isinstance(subparam, torch.Tensor):
                        subparam.data = subparam.data.to(device)
                        if subparam._grad is not None:
                            subparam._grad.data = subparam._grad.data.to(
                                device)

    torch.cuda.empty_cache()

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(
            os.path.join(args.model_name_or_path,
                         "optimizer.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt"),
                       map_location='cpu'))
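    # If optimizer state was loaded above it sits on the CPU (map_location='cpu'),
    # so move its tensors to the training device before the first step.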
    optimizer_to(optimizer, args.device)

    model.to(args.device)

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train
    logger.info("***** Running training *****")
    logger.info("  Max steps = %d", args.max_steps)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)

    global_step = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from
        # the model path
        if "-" in args.model_name_or_path:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
        else:
            global_step = 0
        logger.info(
            "  Continuing training from checkpoint, will skip to saved global_step"
        )
        logger.info("  Continuing training from global step %d", global_step)

    is_hypersphere_training = (args.hyper_align_weight > 0
                               or args.hyper_unif_weight > 0)
    if is_hypersphere_training:
        logger.info(
            f"training with hypersphere property regularization, align weight {args.hyper_align_weight}, unif weight {args.hyper_unif_weight}"
        )
    if not args.dual_training:
        args.dual_loss_weight = 0.0

    tr_loss_dict = {}

    model.zero_grad()
    model.train()
    set_seed(args)  # Added here for reproducibility

    last_ann_no = -1
    train_dataloader = None
    train_dataloader_iter = None
    # dev_ndcg = 0
    step = 0

    if args.single_warmup:
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=args.max_steps)
        if os.path.isfile(os.path.join(
                args.model_name_or_path,
                "scheduler.pt")) and args.load_optimizer_scheduler:
            # Load in optimizer and scheduler states
            scheduler.load_state_dict(
                torch.load(
                    os.path.join(args.model_name_or_path, "scheduler.pt")))

    while global_step < args.max_steps:

        if step % args.gradient_accumulation_steps == 0 and global_step % args.logging_steps == 0:
            # check if new ann training data is available
            ann_no, ann_path, ndcg_json = get_latest_ann_data(
                args.ann_dir, is_grouped=(args.grouping_ann_data > 0))
            if ann_path is not None and ann_no != last_ann_no:
                logger.info("Training on new add data at %s", ann_path)

                time.sleep(30)  # wait until transmission finished

                with open(ann_path, 'r') as f:
                    ann_training_data = f.readlines()
                # marcodev_ndcg = ndcg_json['marcodev_ndcg']
                logging.info(f"loading:\n{ndcg_json}")
                ann_checkpoint_path = ndcg_json['checkpoint']
                ann_checkpoint_no = get_checkpoint_no(ann_checkpoint_path)

                aligned_size = (len(ann_training_data) //
                                args.world_size) * args.world_size
                ann_training_data = ann_training_data[:aligned_size]

                logger.info(
                    "Total ann queries: %d",
                    len(ann_training_data) if args.grouping_ann_data < 0 else
                    len(ann_training_data) * args.grouping_ann_data)

                if args.grouping_ann_data > 0:
                    if args.polling_loaded_data_batch_from_group:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_polling(
                                args, query_cache, passage_cache))
                    else:
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetGroupedTrainingDataProcessingFn_origin(
                                args, query_cache, passage_cache))
                else:
                    if not args.dual_training:
                        if args.triplet:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTripletTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                        else:
                            train_dataset = StreamingDataset(
                                ann_training_data,
                                GetTrainingDataProcessingFn(
                                    args, query_cache, passage_cache))
                    else:
                        # return quadruplet
                        train_dataset = StreamingDataset(
                            ann_training_data,
                            GetQuadrapuletTrainingDataProcessingFn(
                                args, query_cache, passage_cache))

                train_dataloader = DataLoader(train_dataset,
                                              batch_size=args.train_batch_size)
                train_dataloader_iter = iter(train_dataloader)

                # re-warmup
                if not args.single_warmup:
                    scheduler = get_linear_schedule_with_warmup(
                        optimizer,
                        num_warmup_steps=args.warmup_steps,
                        num_training_steps=len(ann_training_data)
                        if args.grouping_ann_data < 0 else
                        len(ann_training_data) * args.grouping_ann_data)

                if args.local_rank != -1:
                    dist.barrier()

                if is_first_worker():
                    # add ndcg at checkpoint step used instead of current step
                    for key in ndcg_json:
                        if "marcodev" in key:
                            tb_writer.add_scalar(key, ndcg_json[key],
                                                 ann_checkpoint_no)

                    if 'trec2019_ndcg' in ndcg_json:
                        tb_writer.add_scalar("trec2019_ndcg",
                                             ndcg_json['trec2019_ndcg'],
                                             ann_checkpoint_no)

                    if last_ann_no != -1:
                        tb_writer.add_scalar("epoch", last_ann_no,
                                             global_step - 1)
                    tb_writer.add_scalar("epoch", ann_no, global_step)
                last_ann_no = ann_no

        try:
            batch = next(train_dataloader_iter)
        except StopIteration:
            logger.info("Finished iterating current dataset, begin reiterate")
            train_dataloader_iter = iter(train_dataloader)
            batch = next(train_dataloader_iter)

        # original way
        if args.grouping_ann_data <= 0:
            batch = tuple(t.to(args.device) for t in batch)
            if args.triplet:
                inputs = {
                    "query_ids": batch[0].long(),
                    "attention_mask_q": batch[1].long(),
                    "input_ids_a": batch[3].long(),
                    "attention_mask_a": batch[4].long(),
                    "input_ids_b": batch[6].long(),
                    "attention_mask_b": batch[7].long()
                }
                if args.dual_training:
                    inputs["neg_query_ids"] = batch[9].long()
                    inputs["attention_mask_neg_query"] = batch[10].long()
                    inputs["prime_loss_weight"] = args.prime_loss_weight
                    inputs["dual_loss_weight"] = args.dual_loss_weight
            else:
                inputs = {
                    "input_ids_a": batch[0].long(),
                    "attention_mask_a": batch[1].long(),
                    "input_ids_b": batch[3].long(),
                    "attention_mask_b": batch[4].long(),
                    "labels": batch[6]
                }
        else:
            # the grouped loader relies on the default collate_fn to stack each
            # item["q_pos"] / ["d_pos"] / ... field into batched tensors
            inputs = {
                "query_ids": batch["q_pos"][0].to(args.device).long(),
                "attention_mask_q": batch["q_pos"][1].to(args.device).long(),
                "input_ids_a": batch["d_pos"][0].to(args.device).long(),
                "attention_mask_a": batch["d_pos"][1].to(args.device).long(),
                "input_ids_b": batch["d_neg"][0].to(args.device).long(),
                "attention_mask_b": batch["d_neg"][1].to(args.device).long(),
            }
            if args.dual_training:
                inputs["neg_query_ids"] = batch["q_neg"][0].to(
                    args.device).long()
                inputs["attention_mask_neg_query"] = batch["q_neg"][1].to(
                    args.device).long()
                inputs["prime_loss_weight"] = args.prime_loss_weight
                inputs["dual_loss_weight"] = args.dual_loss_weight

        inputs["temperature"] = args.temperature
        inputs["loss_objective"] = args.loss_objective_function

        if is_hypersphere_training:
            inputs["alignment_weight"] = args.hyper_align_weight
            inputs["uniformity_weight"] = args.hyper_unif_weight

        step += 1

        if args.local_rank != -1:
            # sync gradients only at gradient accumulation step
            if step % args.gradient_accumulation_steps == 0:
                outputs = model(**inputs)
            else:
                with model.no_sync():
                    outputs = model(**inputs)
        else:
            outputs = model(**inputs)
        # model outputs are always tuple in transformers (see doc)
        loss = outputs[0]

        loss_item_dict = outputs[1]

        if args.n_gpu > 1:
            loss = loss.mean()  # mean() to average on multi-gpu parallel training
            for k in loss_item_dict:
                loss_item_dict[k] = loss_item_dict[k].mean()

        if args.gradient_accumulation_steps > 1:
            loss = loss / args.gradient_accumulation_steps
            for k in loss_item_dict:
                loss_item_dict[
                    k] = loss_item_dict[k] / args.gradient_accumulation_steps

        if args.fp16:
            with amp.scale_loss(loss, optimizer) as scaled_loss:
                scaled_loss.backward()
        else:
            if args.local_rank != -1:
                if step % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()
            else:
                loss.backward()

        def incremental_tr_loss(tr_loss_dict, loss_item_dict, total_loss):
            for k in loss_item_dict:
                if k not in tr_loss_dict:
                    tr_loss_dict[k] = loss_item_dict[k].item()
                else:
                    tr_loss_dict[k] += loss_item_dict[k].item()
            if "loss_total" not in tr_loss_dict:
                tr_loss_dict["loss_total"] = total_loss.item()
            else:
                tr_loss_dict["loss_total"] += total_loss.item()
            return tr_loss_dict
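        # Accumulate every loss component plus the total between logging steps;
        # the running sums are averaged over logging_steps and reset at each log.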

        tr_loss_dict = incremental_tr_loss(tr_loss_dict,
                                           loss_item_dict,
                                           total_loss=loss)

        if step % args.gradient_accumulation_steps == 0:
            if args.fp16:
                torch.nn.utils.clip_grad_norm_(amp.master_params(optimizer),
                                               args.max_grad_norm)
            else:
                torch.nn.utils.clip_grad_norm_(model.parameters(),
                                               args.max_grad_norm)

            optimizer.step()
            scheduler.step()  # Update learning rate schedule
            model.zero_grad()
            global_step += 1

            if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                logs = {}
                learning_rate_scalar = scheduler.get_lr()[0]

                logs["learning_rate"] = learning_rate_scalar

                for k in tr_loss_dict:
                    logs[k] = tr_loss_dict[k] / args.logging_steps
                tr_loss_dict = {}

                if is_first_worker():
                    for key, value in logs.items():
                        tb_writer.add_scalar(key, value, global_step)
                    logger.info(json.dumps({**logs, **{"step": global_step}}))

            if (is_first_worker() and args.save_steps > 0
                    and global_step % args.save_steps == 0):
                # Save model checkpoint
                output_dir = os.path.join(args.output_dir,
                                          "checkpoint-{}".format(global_step))
                if not os.path.exists(output_dir):
                    os.makedirs(output_dir)
                model_to_save = (
                    model.module if hasattr(model, "module") else model
                )  # Take care of distributed/parallel training
                model_to_save.save_pretrained(output_dir)
                tokenizer.save_pretrained(output_dir)

                torch.save(args, os.path.join(output_dir, "training_args.bin"))
                logger.info("Saving model checkpoint to %s", output_dir)

                torch.save(optimizer.state_dict(),
                           os.path.join(output_dir, "optimizer.pt"))
                torch.save(scheduler.state_dict(),
                           os.path.join(output_dir, "scheduler.pt"))
                logger.info("Saving optimizer and scheduler states to %s",
                            output_dir)

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step
Example #17
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs
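        # i.e. optimizer steps per epoch times number of epochs, assuming
        # expected_train_size counts training examples; e.g. (hypothetical
        # numbers) 1,000,000 // 128 * 3 = 23,436 total steps.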

    logger.info("t_total = %d", t_total)
    # layerwise optimization for lamb
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(
                    args.model_name_or_path,
                    "scheduler.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    tr_acc, logging_acc = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            if not args.reset_iter:
                if steps_trained_in_current_epoch > 0:
                    steps_trained_in_current_epoch -= 1
                    continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)

            if (step + 1) % args.gradient_accumulation_steps == 0:
                outputs = model(*batch)
            else:
                with model.no_sync():
                    outputs = model(*batch)
            # model outputs are always tuple in transformers (see doc)
            loss = outputs[0]
            acc = outputs[1]

            if args.n_gpu > 1:
                loss = loss.mean()  # mean() to average on multi-gpu parallel training
                acc = acc.float().mean()
            if args.gradient_accumulation_steps > 1:
                loss = loss / args.gradient_accumulation_steps
                acc = acc / args.gradient_accumulation_steps

            if args.fp16:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                if (step + 1) % args.gradient_accumulation_steps == 0:
                    loss.backward()
                else:
                    with model.no_sync():
                        loss.backward()

            tr_loss += loss.item()
            tr_acc += acc.item()

            if (step + 1) % args.gradient_accumulation_steps == 0:
                if args.fp16:
                    torch.nn.utils.clip_grad_norm_(
                        amp.master_params(optimizer), args.max_grad_norm)
                else:
                    torch.nn.utils.clip_grad_norm_(model.parameters(),
                                                   args.max_grad_norm)

                optimizer.step()
                scheduler.step()  # Update learning rate schedule
                model.zero_grad()
                global_step += 1

                if (is_first_worker() and args.save_steps > 0
                        and global_step % args.save_steps == 0):
                    # Save model checkpoint
                    output_dir = os.path.join(
                        args.output_dir, "checkpoint-{}".format(global_step))
                    if not os.path.exists(output_dir):
                        os.makedirs(output_dir)
                    if 'fairseq' not in args.train_model_type:
                        model_to_save = (
                            model.module if hasattr(model, "module") else model
                        )  # Take care of distributed/parallel training
                        model_to_save.save_pretrained(output_dir)
                        tokenizer.save_pretrained(output_dir)
                    else:
                        torch.save(model.state_dict(),
                                   os.path.join(output_dir, 'model.pt'))

                    torch.save(args,
                               os.path.join(output_dir, "training_args.bin"))
                    logger.info("Saving model checkpoint to %s", output_dir)

                    torch.save(optimizer.state_dict(),
                               os.path.join(output_dir, "optimizer.pt"))
                    torch.save(scheduler.state_dict(),
                               os.path.join(output_dir, "scheduler.pt"))
                    logger.info("Saving optimizer and scheduler states to %s",
                                output_dir)
                if args.local_rank != -1:
                    dist.barrier()

                if args.logging_steps > 0 and global_step % args.logging_steps == 0:
                    logs = {}
                    if args.evaluate_during_training and global_step % (
                            args.logging_steps_per_eval *
                            args.logging_steps) == 0:
                        model.eval()
                        reranking_mrr, full_ranking_mrr = passage_dist_eval(
                            args, model, tokenizer)
                        if is_first_worker():
                            print("Reranking/Full ranking mrr: {0}/{1}".format(
                                str(reranking_mrr), str(full_ranking_mrr)))
                            mrr_dict = {
                                "reranking": float(reranking_mrr),
                                "full_raking": float(full_ranking_mrr)
                            }
                            tb_writer.add_scalars("mrr", mrr_dict, global_step)
                            print(args.output_dir)

                    loss_scalar = (tr_loss - logging_loss) / args.logging_steps
                    learning_rate_scalar = scheduler.get_lr()[0]
                    logs["learning_rate"] = learning_rate_scalar
                    logs["loss"] = loss_scalar
                    logging_loss = tr_loss

                    acc_scalar = (tr_acc - logging_acc) / args.logging_steps
                    logs["acc"] = acc_scalar
                    logging_acc = tr_acc

                    if is_first_worker():
                        for key, value in logs.items():
                            tb_writer.add_scalar(key, value, global_step)
                        tb_writer.add_scalar("epoch", m_epoch, global_step)
                        logger.info(json.dumps({**logs, **{"step": global_step}}))
                    if args.local_rank != -1:
                        dist.barrier()

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
Example #18
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):

    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/','')
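    # For each text collection below, inference_or_load_embedding either runs
    # encoder inference with this checkpoint or reuses embeddings already cached
    # on disk.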

    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="train-query", emb_prefix="query_",
        is_query_inference=True)
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="test-query", emb_prefix="dev_query_",
        is_query_inference=True)
    dev_query_embedding_trivia, dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-test-query", emb_prefix="trivia_dev_query_",
        is_query_inference=True)
    real_dev_query_embedding, real_dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="dev-qas-query", emb_prefix="real-dev_query_",
        is_query_inference=True)
    real_dev_query_embedding_trivia, real_dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-dev-qas-query", emb_prefix="trivia_real-dev_query_",
        is_query_inference=True)

    # passage_embedding is None when args.split_ann_search is True
    passage_embedding, passage_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="passages", emb_prefix="passage_",
        is_query_inference=False, load_emb=not args.split_ann_search)
    
    
    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if args.local_rank != -1:
        dist.barrier()

    # if the passage id mapping was not loaded above, reload it (ids only) on the first worker
    if passage_embedding2id is None and is_first_worker():
        _, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=None, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_",
            is_query_inference=False, load_emb=False)
        logger.info(f"document id size: {passage_embedding2id.shape}")

    if is_first_worker():
        # passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia, dev_answers, dev_answers_trivia = preloaded_data

        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            logger.info('passage embedding shape: %s', str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(16)
            cpu_index = faiss.IndexFlatIP(dim)
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            logger.info("***** Done ANN Index *****")
            dev_D, dev_I = index.search(dev_query_embedding, 100)  # [num queries, topk]
            _, dev_I_trivia = index.search(dev_query_embedding_trivia, 100)
            # also search the "real" dev queries so the trec/validation calls
            # below are defined in this branch as well
            real_dev_D, real_dev_I = index.search(real_dev_query_embedding, 100)
            logger.info("Start searching for query embedding with length %d",
                        len(query_embedding))
            _, I = index.search(query_embedding, top_k)  # [num queries, topk]
        else:
            _, dev_I_trivia, real_dev_D, real_dev_I  = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    top_k_dev=100,
                    top_k=args.topk_training,
                    checkpoint_step=checkpoint_step,
                    dev_query_emb=dev_query_embedding_trivia,
                    train_query_emb=real_dev_query_embedding,
                    emb_prefix="passage_",two_query_set=True,
            )
            dev_D, dev_I, _, I  = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    top_k_dev=100,
                    top_k=args.topk_training,
                    checkpoint_step=checkpoint_step,
                    dev_query_emb=dev_query_embedding,
                    train_query_emb=query_embedding,
                    emb_prefix="passage_",two_query_set=True,
            )
        save_trec_file(
            dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
            trec_save_path=os.path.join(args.training_dir, "ann_data",
                                        "nq-test_" + checkpoint_step + ".trec"),
            topN=100)
        save_trec_file(
            real_dev_query_embedding2id, passage_embedding2id, real_dev_I,
            real_dev_D,
            trec_save_path=os.path.join(args.training_dir, "ann_data",
                                        "nq-dev_" + checkpoint_step + ".trec"),
            topN=100)
        # measure ANN retrieval accuracy (top-k hit rates)
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)
        real_dev_top_k_hits = validate(passage_text, dev_answers, real_dev_I, real_dev_query_embedding2id, passage_embedding2id)
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I_trivia, dev_query_embedding2id_trivia, passage_embedding2id)
        query_range_number = I.shape[0]
        json_dump_dict = {
            'top20': top_k_hits[19], 'top100': top_k_hits[99], 'top20_trivia': top_k_hits_trivia[19],
            'dev_top20': real_dev_top_k_hits[19], 'dev_top100': real_dev_top_k_hits[99],  
            'top100_trivia': top_k_hits_trivia[99], 'checkpoint': checkpoint_path, 'n_train_query':query_range_number,
        }
        logger.info(json_dump_dict)

        
        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(args, passage_text, train_answers, query_embedding2id, passage_embedding2id, I, train_pos_id)

        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        if args.dual_training:
            
            assert args.split_ann_search and args.gpu_index  # both flags are required for d2q mining
            logger.info("***** Begin ANN Index for dual d2q task *****")
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            cpu_index = faiss.IndexFlatIP(query_embedding.shape[1])
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(query_embedding)
            logger.info("***** Done building ANN Index for dual d2q task *****")

            # train_pos_id : a list, idx -> int pid
            train_pos_id_inversed = {}
            for qidx in range(query_embedding2id.shape[0]):
                qid = query_embedding2id[qidx]
                pid = int(train_pos_id[qid])
                if pid not in train_pos_id_inversed:
                    train_pos_id_inversed[pid]=[qid]
                else:
                    train_pos_id_inversed[pid].append(qid)
            
            possitive_training_passage_id = [ train_pos_id[t] for t in query_embedding2id] # 
            # compatible with MaxP
            possitive_training_passage_id_embidx=[]
            possitive_training_passage_id_to_subset_embidx={} # pid to indexs in pos_pas_embs 
            possitive_training_passage_id_emb_counts=0
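            # With MaxP-style encoding a single document id can map to several
            # passage embeddings, so gather every embedding row for each
            # positive pid before searching.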
            for pos_pid in possitive_training_passage_id:
                embidx=np.asarray(np.where(passage_embedding2id==pos_pid)).flatten()
                possitive_training_passage_id_embidx.append(embidx)
                possitive_training_passage_id_to_subset_embidx[int(pos_pid)] = np.asarray(range(possitive_training_passage_id_emb_counts,possitive_training_passage_id_emb_counts+embidx.shape[0]))
                possitive_training_passage_id_emb_counts += embidx.shape[0]
            possitive_training_passage_id_embidx=np.concatenate(possitive_training_passage_id_embidx,axis=0)
            
            if not args.split_ann_search:
                D, I = index.search(passage_embedding[possitive_training_passage_id_embidx], args.topk_training_d2q) 
            else:
                positive_p_embs = loading_possitive_document_embedding(logger,args.output_dir,checkpoint_step,possitive_training_passage_id_embidx,emb_prefix="passage_",)
                assert positive_p_embs.shape[0] == len(possitive_training_passage_id)
                D, I = index.search(positive_p_embs, args.topk_training_d2q) 
                del positive_p_embs
            index.reset()
            logger.info("***** Finish ANN searching for dual d2q task, construct  *****")
            passage_negative_queries = GenerateNegativeQueryID(args, passage_text,train_answers, query_embedding2id, passage_embedding2id[possitive_training_passage_id_embidx], closest_ans=I, training_query_positive_id_inversed=train_pos_id_inversed)
            logger.info("***** Done ANN searching for negative queries *****")


        logger.info("***** Construct ANN Triplet *****")
        prefix = "ann_grouped_training_data_" if args.grouping_ann_data > 0 else "ann_training_data_"
        train_data_output_path = os.path.join(
            args.output_dir, prefix + str(output_num))
        query_range = list(range(query_range_number))
        random.shuffle(query_range)
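        # Two output formats follow: grouped JSON lines (each line packs
        # args.grouping_ann_data query/positive/negative groups) or the original
        # tab-separated "qid<TAB>pos_pid<TAB>neg_pids[<TAB>neg_qids]" lines.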
        if args.grouping_ann_data > 0 :
            with open(train_data_output_path, 'w') as f:
                counting=0
                pos_q_group={}
                pos_d_group={}
                neg_D_group={} # {0:[], 1:[], 2:[]...}
                if args.dual_training:
                    neg_Q_group={}
                for query_idx in query_range: 
                    query_id = query_embedding2id[query_idx]
                    pos_pid = train_pos_id[query_id]
                    
                    pos_q_group[counting]=int(query_id)
                    pos_d_group[counting]=int(pos_pid)

                    neg_D_group[counting]=[int(neg_pid) for neg_pid in query_negative_passage[query_id]]
                    if args.dual_training:
                        neg_Q_group[counting]=[int(neg_qid) for neg_qid in passage_negative_queries[pos_pid]]
                    counting +=1
                    if counting >= args.grouping_ann_data:
                        jsonline_dict={}
                        jsonline_dict["pos_q_group"]=pos_q_group
                        jsonline_dict["pos_d_group"]=pos_d_group
                        jsonline_dict["neg_D_group"]=neg_D_group
                        
                        if args.dual_training:
                            jsonline_dict["neg_Q_group"]=neg_Q_group

                        f.write(f"{json.dumps(jsonline_dict)}\n")

                        counting=0
                        pos_q_group={}
                        pos_d_group={}
                        neg_D_group={} # {0:[], 1:[], 2:[]...}
                        if args.dual_training:
                            neg_Q_group={}
             
        else:
            # not support dualtraining
            with open(train_data_output_path, 'w') as f:
                for query_idx in query_range: 
                    query_id = query_embedding2id[query_idx]
                    # if not query_id in train_pos_id:
                    #     continue
                    pos_pid = train_pos_id[query_id]
                    
                    if not args.dual_training:
                        f.write(
                            "{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(
                                    str(neg_pid) for neg_pid in query_negative_passage[query_id])))
                    else:
                        # if pos_pid not in effective_p_id or pos_pid not in training_query_positive_id_inversed:
                        #     continue
                        f.write(
                            "{}\t{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(
                                    str(neg_pid) for neg_pid in query_negative_passage[query_id]),
                                ','.join(
                                    str(neg_qid) for neg_qid in passage_negative_queries[pos_pid])
                            )
                        )
        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump(json_dump_dict, f)
def generate_new_ann(args, checkpoint_path):
    if args.gpu_index:
        clean_faiss_gpu()
    if not args.not_load_model_for_inference:
        config, tokenizer, model = load_model(args, checkpoint_path)

    checkpoint_step = checkpoint_path.split('-')[-1].replace('/', '')

    def evaluation(dev_query_embedding2id,
                   passage_embedding2id,
                   dev_I,
                   dev_D,
                   trec_prefix="real-dev_query_",
                   test_set="trec2019",
                   split_idx=-1,
                   d2q_eval=False,
                   d2q_qrels=None):
        if d2q_eval:
            qrels = d2q_qrels
        else:
            if args.data_type == 0:
                if test_set == "marcodev":
                    qrels = "../data/raw_data/msmarco-docdev-qrels.tsv"
                elif test_set == "trec2019":
                    qrels = "../data/raw_data/2019qrels-docs.txt"
            elif args.data_type == 1:
                if test_set == "marcodev":
                    qrels = "../data/raw_data/qrels.dev.small.tsv"
            else:
                logging.error("unrecognized args.data_type: %s", args.data_type)
                exit()
        trec_path = os.path.join(args.output_dir,
                                 trec_prefix + str(checkpoint_step) + ".trec")
        save_trec_file(dev_query_embedding2id,
                       passage_embedding2id,
                       dev_I,
                       dev_D,
                       trec_save_path=trec_path,
                       topN=200)
        convert_trec_to_MARCO_id(data_type=args.data_type,
                                 test_set=test_set,
                                 processed_data_dir=args.data_dir,
                                 trec_path=trec_path,
                                 d2q_reversed_trec_file=d2q_eval)

        trec_path = trec_path.replace(".trec", ".formatted.trec")
        met = Metric()
        if split_idx >= 0:
            split_file_path = qrels + f"{args.dev_split_num}_fold.split_dict"
            with open(split_file_path, 'rb') as f:
                split = pickle.load(f)
        else:
            split = None

        ndcg10 = met.get_metric(qrels, trec_path, 'ndcg_cut_10', split,
                                split_idx)
        mrr10 = met.get_mrr(qrels, trec_path, 'mrr_cut_10', split, split_idx)
        mrr100 = met.get_mrr(qrels, trec_path, 'mrr_cut_100', split, split_idx)

        logging.info(
            f"evaluation for {test_set}, trec_file {trec_path}, split_idx {split_idx}: "
            f"ndcg_cut_10: {ndcg10}, mrr_cut_10: {mrr10}, mrr_cut_100: {mrr100}")

        return ndcg10

    # Inference
    if args.data_type == 0:
        # TREC DL 2019 evalset
        trec2019_query_embedding, trec2019_query_embedding2id = inference_or_load_embedding(
            args=args,
            logger=logger,
            model=model,
            checkpoint_path=checkpoint_path,
            text_data_prefix="dev-query",
            emb_prefix="dev_query_",
            is_query_inference=True)  # despite the "dev-query" prefix, this is the TREC-DL 2019 test set
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args,
        logger=logger,
        model=model,
        checkpoint_path=checkpoint_path,
        text_data_prefix="real-dev-query",
        emb_prefix="real-dev_query_",
        is_query_inference=True)
    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args,
        logger=logger,
        model=model,
        checkpoint_path=checkpoint_path,
        text_data_prefix="train-query",
        emb_prefix="query_",
        is_query_inference=True)
    if not args.split_ann_search:
        # load the full merged passage embedding matrix into memory
        passage_embedding, passage_embedding2id = inference_or_load_embedding(
            args=args,
            logger=logger,
            model=model,
            checkpoint_path=checkpoint_path,
            text_data_prefix="passages",
            emb_prefix="passage_",
            is_query_inference=False)
    else:
        # load only the passage id mapping (no embedding matrix)
        _, passage_embedding2id = inference_or_load_embedding(
            args=args,
            logger=logger,
            model=model,
            checkpoint_path=checkpoint_path,
            text_data_prefix="passages",
            emb_prefix="passage_",
            is_query_inference=False,
            load_emb=False)
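        # With args.split_ann_search the passage embeddings stay sharded on disk; only the
        # id mapping is loaded here, and the per-shard FAISS search happens later in
        # document_split_faiss_index.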

    # FirstP shape:
    # passage_embedding: [[vec_0], [vec_1], [vec_2], [vec_3], ...]
    # passage_embedding2id: [id0, id1, id2, id3, ...]

    # MaxP shape:
    # passage_embedding: [[vec_0_0], [vec_0_1], [vec_0_2], [vec_0_3], [vec_1_0], [vec_1_1], ...]
    # passage_embedding2id: [id0, id0, id0, id0, id1, id1, ...]
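    # A minimal illustrative sketch (not part of the original pipeline) of how a passage id
    # is mapped back to its embedding rows, which is what the np.where lookups further down
    # rely on; under MaxP one id owns several consecutive rows:
    #
    #   embidx = np.asarray(np.where(passage_embedding2id == pid)).flatten()
    #   vecs = passage_embedding[embidx]  # only valid when embeddings are merged in memory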
    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if not is_first_worker():
        return
    else:
        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            print('passage embedding shape: ' + str(passage_embedding.shape))
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            cpu_index = faiss.IndexFlatIP(dim)
            logger.info("***** Faiss: total {} gpus *****".format(
                faiss.get_num_gpus()))
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            # dev-query search for evaluating ANN retrieval quality (nDCG/MRR)
            logger.info("search dev query")
            dev_D, dev_I = index.search(dev_query_embedding,
                                        100)  # I: [number of queries, topk]
            logger.info("finish")
            logger.info("search train query")
            D, I = index.search(query_embedding,
                                top_k)  # I: [number of queries, topk]
            logger.info("finish")
            index.reset()
        else:
            if args.data_type == 0:
                trec2019_D, trec2019_I, _, _ = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    checkpoint_step=checkpoint_step,
                    top_k_dev=200,
                    top_k=args.topk_training,
                    dev_query_emb=trec2019_query_embedding,
                    train_query_emb=None,
                    emb_prefix="passage_",
                    two_query_set=False,
                )
            dev_D, dev_I, D, I = document_split_faiss_index(
                logger=logger,
                args=args,
                checkpoint_step=checkpoint_step,
                top_k_dev=200,
                top_k=args.topk_training,
                dev_query_emb=dev_query_embedding,
                train_query_emb=query_embedding,
                emb_prefix="passage_")
            logger.info("***** seperately process indexing *****")

        logger.info("***** Done ANN Index *****")

        # dev_ndcg, num_queries_dev = EvalDevQuery(
        #     args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)
        logger.info("***** Begin evaluation *****")
        eval_dict_todump = {'checkpoint': checkpoint_path}

        if args.data_type == 0:
            trec2019_ndcg = evaluation(trec2019_query_embedding2id,
                                       passage_embedding2id,
                                       trec2019_I,
                                       trec2019_D,
                                       trec_prefix="dev_query_",
                                       test_set="trec2019")
        if args.dev_split_num > 0:
            marcodev_ndcg = 0.0
            for i in range(args.dev_split_num):
                ndcg_10_dev_split_i = evaluation(dev_query_embedding2id,
                                                 passage_embedding2id,
                                                 dev_I,
                                                 dev_D,
                                                 trec_prefix="real-dev_query_",
                                                 test_set="marcodev",
                                                 split_idx=i)
                if i != args.testing_split_idx:
                    marcodev_ndcg += ndcg_10_dev_split_i

                eval_dict_todump[
                    f'marcodev_split_{i}_ndcg_cut_10'] = ndcg_10_dev_split_i

            logger.info(
                f"average marcodev ndcg_cut_10: {marcodev_ndcg / (args.dev_split_num - 1)}")
        else:
            marcodev_ndcg = evaluation(dev_query_embedding2id,
                                       passage_embedding2id,
                                       dev_I,
                                       dev_D,
                                       trec_prefix="real-dev_query_",
                                       test_set="marcodev",
                                       split_idx=-1)

        eval_dict_todump['marcodev_ndcg'] = marcodev_ndcg
        if args.save_training_query_trec:
            logger.info(
                "***** Save the ANN searching for negative passages in trec file format *****"
            )
            trec_output_path = os.path.join(
                args.output_dir, "ann_training_query_retrieval_" +
                str(checkpoint_step) + ".trec")
            save_trec_file(query_embedding2id,
                           passage_embedding2id,
                           I,
                           D,
                           trec_output_path,
                           topN=args.topk_training)
            convert_trec_to_MARCO_id(data_type=args.data_type,
                                     test_set="training",
                                     processed_data_dir=args.data_dir,
                                     trec_path=trec_output_path,
                                     d2q_reversed_trec_file=False)

        logger.info("***** Done ANN searching for negative passages *****")

        if args.d2q_task_evaluation and args.d2q_task_marco_dev_qrels is not None:
            with open(os.path.join(args.data_dir, 'pid2offset.pickle'),
                      'rb') as f:
                pid2offset = pickle.load(f)
            real_dev_ANCE_ids = []
            with open(
                    args.d2q_task_marco_dev_qrels +
                    f"{args.dev_split_num}_fold.split_dict", "rb") as f:
                dev_d2q_split_dict = pickle.load(f)
            for i in dev_d2q_split_dict:
                for stringdocid in dev_d2q_split_dict[i]:
                    if args.data_type == 0:
                        real_dev_ANCE_ids.append(pid2offset[int(
                            stringdocid[1:])])
                    else:
                        real_dev_ANCE_ids.append(pid2offset[int(stringdocid)])
            real_dev_ANCE_ids = np.array(real_dev_ANCE_ids).flatten()
            real_dev_possitive_training_passage_id_embidx = []
            for dev_pos_pid in real_dev_ANCE_ids:
                embidx = np.asarray(
                    np.where(passage_embedding2id == dev_pos_pid)).flatten()
                real_dev_possitive_training_passage_id_embidx.append(embidx)
                # possitive_training_passage_id_to_subset_embidx[int(dev_pos_pid)] = np.asarray(range(possitive_training_passage_id_emb_counts,possitive_training_passage_id_emb_counts+embidx.shape[0]))
                # possitive_training_passage_id_emb_counts += embidx.shape[0]
            real_dev_possitive_training_passage_id_embidx = np.concatenate(
                real_dev_possitive_training_passage_id_embidx, axis=0)
            del pid2offset
            if not args.split_ann_search:
                real_dev_positive_p_embs = passage_embedding[
                    real_dev_possitive_training_passage_id_embidx]
            else:
                real_dev_positive_p_embs = loading_possitive_document_embedding(
                    logger,
                    args.output_dir,
                    checkpoint_step,
                    real_dev_possitive_training_passage_id_embidx,
                    emb_prefix="passage_",
                )
            logger.info("***** d2q task evaluation *****")
            cpu_index = faiss.IndexFlatIP(dev_query_embedding.shape[1])
            index = cpu_index
            # index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(dev_query_embedding)
            real_dev_d2q_D, real_dev_d2q_I = index.search(
                real_dev_positive_p_embs, 200)
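            # Note the reversed roles for the d2q evaluation: the index holds dev *query*
            # embeddings and is searched with positive *passage* embeddings, so each row of
            # real_dev_d2q_I lists the queries retrieved for one dev document.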
            if args.dev_split_num > 0:
                d2q_marcodev_ndcg = 0.0
                for i in range(args.dev_split_num):
                    d2q_ndcg_10_dev_split_i = evaluation(
                        real_dev_ANCE_ids,
                        dev_query_embedding2id,
                        real_dev_d2q_I,
                        real_dev_d2q_D,
                        trec_prefix="d2q-dual-task_real-dev_query_",
                        test_set="marcodev",
                        split_idx=i,
                        d2q_eval=True,
                        d2q_qrels=args.d2q_task_marco_dev_qrels)
                    if i != args.testing_split_idx:
                        d2q_marcodev_ndcg += d2q_ndcg_10_dev_split_i
                    eval_dict_todump[
                        f'd2q_marcodev_split_{i}_ndcg_cut_10'] = d2q_ndcg_10_dev_split_i
                logger.info(
                    f"average marcodev d2q task ndcg_cut_10: {d2q_marcodev_ndcg / (args.dev_split_num - 1)}"
                )
            else:
                d2q_marcodev_ndcg = evaluation(
                    real_dev_ANCE_ids,
                    dev_query_embedding2id,
                    real_dev_d2q_I,
                    real_dev_d2q_D,
                    trec_prefix="d2q-dual-task_real-dev_query_",
                    test_set="marcodev",
                    split_idx=-1,
                    d2q_eval=True,
                    d2q_qrels=args.d2q_task_marco_dev_qrels)

            eval_dict_todump['d2q_marcodev_ndcg'] = d2q_marcodev_ndcg

        return None  #dev_ndcg, num_queries_dev
Example #20
0
def train(args, model, tokenizer, f, train_fn):
    """ Train the model """
    tb_writer = None
    if is_first_worker():
        tb_writer = SummaryWriter(log_dir=args.log_dir)

    args.train_batch_size = args.per_gpu_train_batch_size * max(1, args.n_gpu)
    real_batch_size = args.train_batch_size * args.gradient_accumulation_steps * \
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1)

    if args.max_steps > 0:
        t_total = args.max_steps
    else:
        t_total = args.expected_train_size // real_batch_size * args.num_train_epochs

    logger.info("total optimization steps (t_total): %s", t_total)
    # layer-wise parameter groups for the LAMB optimizer
    optimizer_grouped_parameters = []
    layer_optim_params = set()
    for layer_name in [
            "roberta.embeddings", "score_out", "downsample1", "downsample2",
            "downsample3", "embeddingHead"
    ]:
        layer = getattr_recursive(model, layer_name)
        if layer is not None:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)
    if getattr_recursive(model, "roberta.encoder.layer") is not None:
        for layer in model.roberta.encoder.layer:
            optimizer_grouped_parameters.append({"params": layer.parameters()})
            for p in layer.parameters():
                layer_optim_params.add(p)

    optimizer_grouped_parameters.append({
        "params":
        [p for p in model.parameters() if p not in layer_optim_params]
    })
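    # Catch-all parameter group: any parameter not covered by the per-layer groups above
    # (e.g. modules outside the listed names) is still passed to the optimizer.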

    if args.optimizer.lower() == "lamb":
        optimizer = Lamb(optimizer_grouped_parameters,
                         lr=args.learning_rate,
                         eps=args.adam_epsilon)
    elif args.optimizer.lower() == "adamw":
        optimizer = AdamW(optimizer_grouped_parameters,
                          lr=args.learning_rate,
                          eps=args.adam_epsilon)
    else:
        raise Exception(
            "optimizer {0} not recognized! Can only be lamb or adamW".format(
                args.optimizer))

    if args.scheduler.lower() == "linear":
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=args.warmup_steps,
            num_training_steps=t_total)
    elif args.scheduler.lower() == "cosine":
        scheduler = CosineAnnealingLR(optimizer, t_total, 1e-8)
    else:
        raise Exception(
            "Scheduler {0} not recognized! Can only be linear or cosine".
            format(args.scheduler))

    # Check if saved optimizer or scheduler states exist
    if os.path.isfile(os.path.join(
            args.model_name_or_path, "optimizer.pt")) and os.path.isfile(
                os.path.join(
                    args.model_name_or_path,
                    "scheduler.pt")) and args.load_optimizer_scheduler:
        # Load in optimizer and scheduler states
        optimizer.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "optimizer.pt")))
        scheduler.load_state_dict(
            torch.load(os.path.join(args.model_name_or_path, "scheduler.pt")))

    if args.fp16:
        try:
            from apex import amp
        except ImportError:
            raise ImportError(
                "Please install apex from https://www.github.com/nvidia/apex to use fp16 training."
            )
        model, optimizer = amp.initialize(model,
                                          optimizer,
                                          opt_level=args.fp16_opt_level)

    # multi-gpu training (should be after apex fp16 initialization)
    if args.n_gpu > 1:
        model = torch.nn.DataParallel(model)

    # Distributed training (should be after apex fp16 initialization)
    if args.local_rank != -1:
        model = torch.nn.parallel.DistributedDataParallel(
            model,
            device_ids=[args.local_rank],
            output_device=args.local_rank,
            find_unused_parameters=True,
        )

    # Train!
    logger.info("***** Running training *****")
    logger.info("  Num Epochs = %d", args.num_train_epochs)
    logger.info("  Instantaneous batch size per GPU = %d",
                args.per_gpu_train_batch_size)
    logger.info(
        "  Total train batch size (w. parallel, distributed & accumulation) = %d",
        args.train_batch_size * args.gradient_accumulation_steps *
        (torch.distributed.get_world_size() if args.local_rank != -1 else 1),
    )
    logger.info("  Gradient Accumulation steps = %d",
                args.gradient_accumulation_steps)
    logger.info("  Total optimization steps = %d", t_total)

    global_step = 0
    epochs_trained = 0
    steps_trained_in_current_epoch = 0
    # Check if continuing training from a checkpoint
    if os.path.exists(args.model_name_or_path):
        # set global_step to the global_step of the last saved checkpoint from the model path
        try:
            global_step = int(
                args.model_name_or_path.split("-")[-1].split("/")[0])
            epochs_trained = global_step // (args.expected_train_size //
                                             args.gradient_accumulation_steps)
            steps_trained_in_current_epoch = global_step % (
                args.expected_train_size // args.gradient_accumulation_steps)

            logger.info(
                "  Continuing training from checkpoint, will skip to saved global_step"
            )
            logger.info("  Continuing training from epoch %d", epochs_trained)
            logger.info("  Continuing training from global step %d",
                        global_step)
            logger.info("  Will skip the first %d steps in the first epoch",
                        steps_trained_in_current_epoch)
        except Exception:
            logger.info("  Start training from a pretrained model")

    tr_loss, logging_loss = 0.0, 0.0
    model.zero_grad()
    train_iterator = trange(
        epochs_trained,
        int(args.num_train_epochs),
        desc="Epoch",
        disable=args.local_rank not in [-1, 0],
    )
    set_seed(args)  # Added here for reproducibility
    for m_epoch in train_iterator:
        f.seek(0)
        sds = StreamingDataset(f, train_fn)
        epoch_iterator = DataLoader(sds,
                                    batch_size=args.per_gpu_train_batch_size,
                                    num_workers=1)
        count = 0
        avg_cls_norm = 0
        loss_avg = 0
        for step, batch in tqdm(enumerate(epoch_iterator),
                                desc="Iteration",
                                disable=args.local_rank not in [-1, 0]):
            # Skip past any already trained steps if resuming training
            # if not args.reset_iter:
            #     if steps_trained_in_current_epoch > 0:
            #         steps_trained_in_current_epoch -= 1
            #         continue

            model.train()
            batch = tuple(t.to(args.device).long() for t in batch)
            with torch.no_grad():
                outputs = model(*batch)
                cls_norm = outputs[1]
                loss = outputs[0]

                count += 1
                avg_cls_norm += float(cls_norm.cpu().data)
                loss_avg += float(loss.cpu().data)
                print("SEED-Encoder norm: ", cls_norm)

            if count == 1024:
                # print('avg_cls_norm: ',float(avg_cls_norm)/count)
                print('avg_cls_sim: ', float(avg_cls_norm) / count)
                print('avg_loss: ', float(loss_avg) / count)
                return

        if args.max_steps > 0 and global_step > args.max_steps:
            train_iterator.close()
            break

    if args.local_rank == -1 or torch.distributed.get_rank() == 0:
        tb_writer.close()

    return global_step, tr_loss / global_step
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):

    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "query_" + str(latest_step_num)+"_", emb, is_query_inference = True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "test-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "dev_query_"+ str(latest_step_num)+"_", emb, is_query_inference = True)

    dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
    dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
    with dev_query_cache_trivia as emb:
        dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "dev_query_"+ str(latest_step_num)+"_", emb, is_query_inference = True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=False), "passage_"+ str(latest_step_num)+"_", emb, is_query_inference = False, load_cache = False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training 
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
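        # IndexFlatIP performs exact (brute-force) inner-product search, so the retrieved
        # results are the exact top-k passages by dot product rather than approximate ones.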
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure top-k retrieval accuracy on the test queries
        _, dev_I = cpu_index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)

        # measure top-k retrieval accuracy on the TriviaQA test queries
        _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100)  # I: [number of queries, topk]
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I, dev_query_embedding2id_trivia, passage_embedding2id)

        logger.info("Start searching for query embedding with length %d", len(query_embedding))
        _, I = cpu_index.search(query_embedding, top_k) #I: [number of queries, topk]

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(args, passage_text, train_answers, query_embedding2id, passage_embedding2id, I, train_pos_id)

        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range: 
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(query_id, pos_pid, ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'top20': top_k_hits[19], 'top100': top_k_hits[99], 'top20_trivia': top_k_hits_trivia[19], 
                'top100_trivia': top_k_hits_trivia[99], 'checkpoint': checkpoint_path}, f)
Example #22
0
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        training_query_positive_id,
        dev_query_positive_id,
        latest_step_num):
    config, tokenizer, model = load_model(args, checkpoint_path)

    dataFound = False

    if args.end_output_num == 0 and is_first_worker():
        dataFound, query_embedding, query_embedding2id, dev_query_embedding, dev_query_embedding2id, passage_embedding, passage_embedding2id = load_init_embeddings(args)

    if args.local_rank != -1:
        dist.barrier()

    if not dataFound:
        logger.info("***** inference of dev query *****")
        name_ = None
        if args.dataset == "dl_test_2019":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_1":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-1")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_2":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-2")
            name_ = "test_query_"
        else:
            raise Exception('Dataset should be one of {dl_test_2019, dl_test_2019_1, dl_test_2019_2}!!')
        dev_query_cache = EmbeddingCache(dev_query_collection_path)
        with dev_query_cache as emb:
            dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=True), name_ + str(latest_step_num) + "_", emb, is_query_inference=True)

        logger.info("***** inference of passages *****")
        passage_collection_path = os.path.join(args.data_dir, "passages")
        passage_cache = EmbeddingCache(passage_collection_path)
        with passage_cache as emb:
            passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False)
        logger.info("***** Done passage inference *****")

        if args.inference:
            return

        logger.info("***** inference of train query *****")
        train_query_collection_path = os.path.join(args.data_dir, "train-query")
        train_query_cache = EmbeddingCache(train_query_collection_path)
        with train_query_cache as emb:
            query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    else:
        logger.info("***** Found pre-existing embeddings. So not running inference again. *****")

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)

        # Construct the new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor
        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) else (q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]
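        # Illustrative example (hypothetical numbers): with chunk_factor=4, output_num=5 and
        # 1000 train queries, effective_idx = 5 % 4 = 1 and queries_per_chunk = 250, so this
        # iteration mines negatives for queries [250, 500); the last chunk absorbs any remainder.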

        logger.info(
            "Chunked {} queries from {} total train queries".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())

        _, dev_I_dist = cpu_index.search(dev_query_embedding, top_k)
        distrib, samplingDist = getSamplingDist(args, dev_I_dist, dev_query_embedding2id, dev_query_positive_id, passage_embedding2id)
        sampling_dist_data = {'distrib': distrib, 'samplingDist': samplingDist}
        dist_output_path = os.path.join(args.output_dir, "dist_" + str(output_num))
        with open(dist_output_path, 'wb') as f:
            pickle.dump(sampling_dist_data, f)

        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding,
            query_embedding2id,
            passage_embedding,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id,
            samplingDist,
            output_num)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        train_query_cache.open()
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                pos_score = get_BM25_score(query_id, pos_pid, train_query_cache, tokenizer)
                neg_scores = {}
                for neg_pid in query_negative_passage[query_id]:
                    neg_scores[neg_pid] = get_BM25_score(query_id, neg_pid, train_query_cache, tokenizer)
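                # Each output line below: "qid \t pos_pid:bm25 \t neg_pid:bm25,neg_pid:bm25,..."
                # with BM25 scores rounded to three decimals.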
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, str(pos_pid)+":"+str(round(pos_score,3)), ','.join(
                            str(neg_pid)+":"+str(round(neg_scores[neg_pid],3)) for neg_pid in query_negative_passage[query_id])))
        train_query_cache.close()

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev