Example #1
def load_data(args):
    train_ann_path = os.path.join(args.data_dir, "train-sec-ann")
    dev_ann_path = os.path.join(args.data_dir, "dev-sec-ann")
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    train_pos_id = []
    test_pos_id = []

    logger.info("Loading train ann")
    with open(train_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            train_pos_id.append(row[1])

    logger.info("Loading dev ann")
    with open(dev_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            test_pos_id.append(row[1])

    logger.info(
        "Finished loading data, train pos_id length %d, test pos_id length %d",
        len(train_pos_id), len(test_pos_id))

    return (train_pos_id, test_pos_id)
Example #2
def load_data(args):
    passage_path = os.path.join(args.passage_path, "psgs_w100.tsv")
    test_qa_path = os.path.join(args.test_qa_path, "nq-test.csv")
    trivia_test_qa_path = os.path.join(args.trivia_test_qa_path, "trivia-test.csv")
    train_ann_path = os.path.join(args.data_dir, "train-ann")

    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    passage_text = {}
    train_pos_id = []
    train_answers = []
    test_answers = []
    test_questions = []
    test_answers_trivia = []
    test_questions_trivia = []

    logger.info("Loading train ann")
    with open(train_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            train_pos_id.append(int(row[1]))
            # answers are serialized as a Python list literal; eval() parses them
            # (ast.literal_eval would be the safer choice for untrusted data)
            train_answers.append(eval(row[2]))

    logger.info("Loading test answers")
    with open(test_qa_path, "r", encoding="utf-8") as ifile:
        # file format: question, answers
        reader = csv.reader(ifile, delimiter='\t')
        for row in reader:
            test_answers.append(eval(row[1]))
            test_questions.append(str(row[0]))

    logger.info("Loading trivia test answers")
    with open(trivia_test_qa_path, "r", encoding="utf-8") as ifile:
        # file format: question, answers
        reader = csv.reader(ifile, delimiter='\t')
        for row in reader:
            test_answers_trivia.append(eval(row[1]))
            test_questions_trivia.append(str(row[0]))

    logger.info("Loading passages")
    with open(passage_path, "r", encoding="utf-8") as tsvfile:
        reader = csv.reader(tsvfile, delimiter='\t')
        # file format: doc_id, doc_text, title
        for row in reader:
            if row[0] != 'id':
                passage_text[pid2offset[int(row[0])]] = (row[1], row[2])
                if args.do_debug and len(passage_text) > 10:
                    break

    logger.info(
        "Finished loading data, pos_id length %d, train answers length %d, test answers length %d",
        len(train_pos_id), len(train_answers), len(test_answers))

    return (passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia, test_questions, test_questions_trivia)
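The answers column in these QAS files is a serialized Python list (e.g. "['William Shakespeare']"). Below is a minimal, self-contained sketch of parsing one such tab-separated row, using ast.literal_eval as a safer stand-in for the eval call above; the sample row is illustrative only:

# Sketch: parse a DPR-style QAS TSV row ("question<TAB>['answer 1', 'answer 2']")
import ast
import csv
import io

sample_tsv = "who wrote hamlet\t['William Shakespeare']\n"
reader = csv.reader(io.StringIO(sample_tsv), delimiter='\t')
for row in reader:
    question = str(row[0])
    answers = ast.literal_eval(row[1])  # -> ['William Shakespeare']
    print(question, answers)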
Example #3
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):

    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/','')

    query_embedding, query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="train-query", emb_prefix="query_", is_query_inference=True)
    dev_query_embedding, dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="test-query", emb_prefix="dev_query_", is_query_inference=True)
    dev_query_embedding_trivia, dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-test-query", emb_prefix="trivia_dev_query_", is_query_inference=True)
    real_dev_query_embedding, real_dev_query_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="dev-qas-query", emb_prefix="real-dev_query_", is_query_inference=True)
    real_dev_query_embedding_trivia, real_dev_query_embedding2id_trivia = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="trivia-dev-qas-query", emb_prefix="trivia_real-dev_query_", is_query_inference=True)

    # passage_embedding is None when args.split_ann_search is True
    passage_embedding, passage_embedding2id = inference_or_load_embedding(
        args=args, logger=logger, model=model, checkpoint_path=checkpoint_path,
        text_data_prefix="passages", emb_prefix="passage_", is_query_inference=False,
        load_emb=not args.split_ann_search)
    
    
    if args.gpu_index:
        del model  # leave gpu for faiss
        torch.cuda.empty_cache()
        time.sleep(10)

    if args.local_rank != -1:
        dist.barrier()

    # if passage_embedding2id was not loaded above, reload it from the cached embeddings
    if passage_embedding2id is None and is_first_worker():
        _, passage_embedding2id = inference_or_load_embedding(
            args=args, logger=logger, model=None, checkpoint_path=checkpoint_path,
            text_data_prefix="passages", emb_prefix="passage_", is_query_inference=False,
            load_emb=False)
        logger.info(f"document id size: {passage_embedding2id.shape}")

    if is_first_worker():
        # passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia, dev_answers, dev_answers_trivia = preloaded_data

        if not args.split_ann_search:
            dim = passage_embedding.shape[1]
            print('passage embedding shape: ' + str(passage_embedding.shape))
            top_k = args.topk_training 
            faiss.omp_set_num_threads(16)
            cpu_index = faiss.IndexFlatIP(dim)
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(passage_embedding)
            logger.info("***** Done ANN Index *****")
            dev_D, dev_I = index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
            _, dev_I_trivia = index.search(dev_query_embedding_trivia, 100)
            # also search the NQ dev QAS queries so real_dev_D / real_dev_I exist in this branch
            real_dev_D, real_dev_I = index.search(real_dev_query_embedding, 100)
            logger.info("Start searching for query embedding with length %d", len(query_embedding))
            _, I = index.search(query_embedding, top_k)  # I: [number of queries, topk]
        else:
            # first call: trivia test queries (top-100 dev) and NQ dev QAS queries (top-k)
            _, dev_I_trivia, real_dev_D, real_dev_I = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    top_k_dev=100,
                    top_k=args.topk_training,
                    checkpoint_step=checkpoint_step,
                    dev_query_emb=dev_query_embedding_trivia,
                    train_query_emb=real_dev_query_embedding,
                    emb_prefix="passage_", two_query_set=True,
            )
            # second call: NQ test queries (top-100 dev) and train queries (top-k)
            dev_D, dev_I, _, I = document_split_faiss_index(
                    logger=logger,
                    args=args,
                    top_k_dev=100,
                    top_k=args.topk_training,
                    checkpoint_step=checkpoint_step,
                    dev_query_emb=dev_query_embedding,
                    train_query_emb=query_embedding,
                    emb_prefix="passage_", two_query_set=True,
            )
        save_trec_file(
            dev_query_embedding2id, passage_embedding2id, dev_I, dev_D,
            trec_save_path=os.path.join(args.training_dir, "ann_data", "nq-test_" + checkpoint_step + ".trec"),
            topN=100
        )
        save_trec_file(
            real_dev_query_embedding2id, passage_embedding2id, real_dev_I, real_dev_D,
            trec_save_path=os.path.join(args.training_dir, "ann_data", "nq-dev_" + checkpoint_step + ".trec"),
            topN=100
        )
        # measure ANN top-k hit rates
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)
        real_dev_top_k_hits = validate(passage_text, dev_answers, real_dev_I, real_dev_query_embedding2id, passage_embedding2id)
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I_trivia, dev_query_embedding2id_trivia, passage_embedding2id)
        query_range_number = I.shape[0]
        json_dump_dict = {
            'top20': top_k_hits[19], 'top100': top_k_hits[99], 'top20_trivia': top_k_hits_trivia[19],
            'dev_top20': real_dev_top_k_hits[19], 'dev_top100': real_dev_top_k_hits[99],  
            'top100_trivia': top_k_hits_trivia[99], 'checkpoint': checkpoint_path, 'n_train_query':query_range_number,
        }
        logger.info(json_dump_dict)

        
        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(args, passage_text, train_answers, query_embedding2id, passage_embedding2id, I, train_pos_id)

        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        if args.dual_training:
            
            assert args.split_ann_search and args.gpu_index # hard set
            logger.info("***** Begin ANN Index for dual d2q task *****")
            top_k = args.topk_training
            faiss.omp_set_num_threads(args.faiss_omp_num_threads)
            logger.info("***** Faiss: total {} gpus *****".format(faiss.get_num_gpus()))
            cpu_index = faiss.IndexFlatIP(query_embedding.shape[1])
            index = get_gpu_index(cpu_index) if args.gpu_index else cpu_index
            index.add(query_embedding)
            logger.info("***** Done building ANN Index for dual d2q task *****")

            # train_pos_id : a list, idx -> int pid
            train_pos_id_inversed = {}
            for qidx in range(query_embedding2id.shape[0]):
                qid = query_embedding2id[qidx]
                pid = int(train_pos_id[qid])
                if pid not in train_pos_id_inversed:
                    train_pos_id_inversed[pid]=[qid]
                else:
                    train_pos_id_inversed[pid].append(qid)
            
            possitive_training_passage_id = [train_pos_id[t] for t in query_embedding2id]
            # compatible with MaxP: one positive pid may map to several passage embeddings
            possitive_training_passage_id_embidx = []
            possitive_training_passage_id_to_subset_embidx = {}  # pid -> indices into the stacked positive passage embeddings
            possitive_training_passage_id_emb_counts = 0
            for pos_pid in possitive_training_passage_id:
                embidx = np.asarray(np.where(passage_embedding2id == pos_pid)).flatten()
                possitive_training_passage_id_embidx.append(embidx)
                possitive_training_passage_id_to_subset_embidx[int(pos_pid)] = np.asarray(
                    range(possitive_training_passage_id_emb_counts,
                          possitive_training_passage_id_emb_counts + embidx.shape[0]))
                possitive_training_passage_id_emb_counts += embidx.shape[0]
            possitive_training_passage_id_embidx = np.concatenate(possitive_training_passage_id_embidx, axis=0)
            
            if not args.split_ann_search:
                D, I = index.search(passage_embedding[possitive_training_passage_id_embidx], args.topk_training_d2q)
            else:
                positive_p_embs = loading_possitive_document_embedding(
                    logger, args.output_dir, checkpoint_step,
                    possitive_training_passage_id_embidx, emb_prefix="passage_")
                assert positive_p_embs.shape[0] == len(possitive_training_passage_id)
                D, I = index.search(positive_p_embs, args.topk_training_d2q)
                del positive_p_embs
            index.reset()
            logger.info("***** Finish ANN searching for dual d2q task, construct  *****")
            passage_negative_queries = GenerateNegativeQueryID(args, passage_text,train_answers, query_embedding2id, passage_embedding2id[possitive_training_passage_id_embidx], closest_ans=I, training_query_positive_id_inversed=train_pos_id_inversed)
            logger.info("***** Done ANN searching for negative queries *****")


        logger.info("***** Construct ANN Triplet *****")
        prefix =  "ann_grouped_training_data_" if args.grouping_ann_data > 0  else "ann_training_data_"
        train_data_output_path = os.path.join(
            args.output_dir, prefix + str(output_num))
        query_range = list(range(query_range_number))
        random.shuffle(query_range)
        if args.grouping_ann_data > 0 :
            with open(train_data_output_path, 'w') as f:
                counting = 0
                pos_q_group = {}
                pos_d_group = {}
                neg_D_group = {}  # {0: [...], 1: [...], ...}
                if args.dual_training:
                    neg_Q_group = {}
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    pos_pid = train_pos_id[query_id]

                    pos_q_group[counting] = int(query_id)
                    pos_d_group[counting] = int(pos_pid)

                    neg_D_group[counting] = [int(neg_pid) for neg_pid in query_negative_passage[query_id]]
                    if args.dual_training:
                        neg_Q_group[counting] = [int(neg_qid) for neg_qid in passage_negative_queries[pos_pid]]
                    counting += 1
                    if counting >= args.grouping_ann_data:
                        jsonline_dict = {}
                        jsonline_dict["pos_q_group"] = pos_q_group
                        jsonline_dict["pos_d_group"] = pos_d_group
                        jsonline_dict["neg_D_group"] = neg_D_group
                        if args.dual_training:
                            jsonline_dict["neg_Q_group"] = neg_Q_group

                        f.write(f"{json.dumps(jsonline_dict)}\n")

                        counting = 0
                        pos_q_group = {}
                        pos_d_group = {}
                        neg_D_group = {}
                        if args.dual_training:
                            neg_Q_group = {}
                # note: a trailing partial group with fewer than args.grouping_ann_data queries is not written
        else:
            # ungrouped output: one tab-separated triplet per query (with an extra negative-query column when dual_training)
            with open(train_data_output_path, 'w') as f:
                for query_idx in query_range: 
                    query_id = query_embedding2id[query_idx]
                    # if not query_id in train_pos_id:
                    #     continue
                    pos_pid = train_pos_id[query_id]
                    
                    if not args.dual_training:
                        f.write(
                            "{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(
                                    str(neg_pid) for neg_pid in query_negative_passage[query_id])))
                    else:
                        # if pos_pid not in effective_p_id or pos_pid not in training_query_positive_id_inversed:
                        #     continue
                        f.write(
                            "{}\t{}\t{}\t{}\n".format(
                                query_id, pos_pid,
                                ','.join(
                                    str(neg_pid) for neg_pid in query_negative_passage[query_id]),
                                ','.join(
                                    str(neg_qid) for neg_qid in passage_negative_queries[pos_pid])
                            )
                        )
        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump(json_dump_dict, f)
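Each line written in the grouping_ann_data > 0 branch above is one JSON object whose pos_q_group, pos_d_group, neg_D_group (and, with dual training, neg_Q_group) dictionaries share the same group-local keys. A minimal sketch of reading such a file back; the file name is hypothetical:

# Sketch: iterate over the grouped ANN training data written above (one JSON object per line)
import json

with open("ann_grouped_training_data_0") as f:  # hypothetical path
    for line in f:
        group = json.loads(line)
        for k, qid in group["pos_q_group"].items():  # keys are group-local positions (as strings)
            pos_pid = group["pos_d_group"][k]
            neg_pids = group["neg_D_group"][k]
            neg_qids = group.get("neg_Q_group", {}).get(k, [])
            print(qid, pos_pid, len(neg_pids), len(neg_qids))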
Example #4
def validate(args, closest_docs, dev_scores, query_embedding2id,
             passage_embedding2id):

    logger.info('Matching answers in top docs...')
    scores = dict()

    count = 0
    total = 0
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    passage_path = os.path.join(args.passage_path, "hotpot_wiki.tsv")
    idx2title = dict()
    title2text = dict()
    with open(passage_path, "r", encoding="utf-8") as tsvfile:
        reader = csv.reader(
            tsvfile,
            delimiter='\t',
        )
        # file format: doc_id, doc_text, title
        for row in reader:
            if row[0] != 'id':
                idx2title[int(row[0])] = row[2]
                title2text[row[2]] = row[1]
    with open(args.data_dir + '/hotpot_dev_fullwiki_v1.json', 'r') as fin:
        dataset = json.load(fin)

    type_dict = pickle.load(open(args.data_dir + '/dev_type_results.pkl',
                                 'rb'))
    instances = list()

    dev_ann_path = os.path.join(args.data_dir, "dev-sec-ann")
    test_id = list()
    test_pre_et = list()
    with open(dev_ann_path, 'r', encoding='utf8') as f:
        # file format: q_id, positive_pid, answers
        tsvreader = csv.reader(f, delimiter="\t")
        for row in tsvreader:
            test_id.append(row[0])
            test_pre_et.append(row[1])

    first_hop_ets = pickle.load(
        open(args.data_dir + '/dev_first_hop_pred.pkl', 'rb'))
    pred_dict = dict()
    for query_idx in range(closest_docs.shape[0]):

        query_id = query_embedding2id[query_idx]

        qid = test_id[query_id]
        pre_et = test_pre_et[query_id]

        if qid not in pred_dict:
            pred_dict[qid] = {'chain': list(), 'score': list()}

        pre_idx = first_hop_ets[qid]['pred'].index(pre_et)
        pre_score = first_hop_ets[qid]['score'][pre_idx]
        all_pred = closest_docs[query_idx]
        scs = dev_scores[query_idx]
        for i in range(len(scs)):
            pid = int(passage_embedding2id[all_pred[i]])
            if pid in offset2pid:
                pred_dict[qid]['chain'].append(
                    pre_et + '#######' + normalize(idx2title[offset2pid[pid]]))
                pred_dict[qid]['score'].append(float(scs[i]) + pre_score)

    print(len(pred_dict))
    sec_hop_pred = dict()
    for data in dataset:
        qid = data['_id']
        all_pairs = list()
        supp_set = set()
        for supp in data['supporting_facts']:
            title = supp[0]
            supp_set.add(normalize(title))
        #total += 1
        supp_set = list(supp_set)

        if qid in pred_dict:
            doc_scores = pred_dict[qid]['score']
            idxs = sorted(range(len(doc_scores)),
                          key=lambda k: doc_scores[k],
                          reverse=True)
            for idx in idxs[:250]:
                all_pairs.append(pred_dict[qid]['chain'][idx])
        sec_hop_pred[qid] = all_pairs

    pickle.dump(sec_hop_pred,
                open(args.data_dir + '/dev_sec_hop_pred_top250.pkl', 'wb'))
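The top-250 selection above ranks chains by score with sorted(range(len(doc_scores)), ...). A minimal equivalent using numpy.argsort, shown only to illustrate the idiom; the scores are made up:

# Sketch: indices of the highest scores, equivalent to
# sorted(range(len(doc_scores)), key=lambda k: doc_scores[k], reverse=True)[:topn]
import numpy as np

doc_scores = [0.2, 1.5, 0.7, 1.1]
topn = 2
idxs = np.argsort(doc_scores)[::-1][:topn]
print(idxs.tolist())  # [1, 3]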
Example #5
def generate_new_ann(args):
    #print(test_pos_id.shape)
    #model = None
    model = load_model(args, args.model)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    latest_step_num = args.latest_num

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-eval-sec")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "dev-sec_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True,
            load_cache=args.load_cache)
    #exit()
    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=False,
            load_cache=args.load_cache)

    dim = passage_embedding.shape[1]
    #print(dev_query_embedding.shape)
    #print('passage embedding shape: ' + str(passage_embedding.shape))
    print('dev embedding shape: ' + str(dev_query_embedding.shape))

    faiss.omp_set_num_threads(16)
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(passage_embedding)

    logger.info('Data indexing completed.')
    # search dev queries in chunks of 5000 to bound memory use
    nums = int(dev_query_embedding.shape[0] / 5000) + 1
    II = list()
    sscores = list()
    for i in range(nums):
        score, idx = cpu_index.search(
            dev_query_embedding[i * 5000:(i + 1) * 5000],
            args.topk)  # idx: [number of queries, topk]
        II.append(idx)
        sscores.append(score)
        logger.info("Split done %d", i)

    dev_I = np.concatenate(II, axis=0)
    scores = np.concatenate(sscores, axis=0)

    validate(args, dev_I, scores, dev_query_embedding2id, passage_embedding2id)
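The chunked search loop above (5000 queries at a time) can be wrapped into a small helper. Below is a minimal sketch with synthetic data; the helper name and sizes are illustrative, not part of the original code:

# Sketch: batched exact inner-product search with FAISS, mirroring the 5000-query chunks above
import faiss
import numpy as np

def batched_search(index, queries, topk, batch_size=5000):
    scores, ids = [], []
    for start in range(0, queries.shape[0], batch_size):
        s, i = index.search(queries[start:start + batch_size], topk)
        scores.append(s)
        ids.append(i)
    return np.concatenate(scores, axis=0), np.concatenate(ids, axis=0)

dim = 8
passages = np.random.rand(100, dim).astype('float32')
queries = np.random.rand(12, dim).astype('float32')
index = faiss.IndexFlatIP(dim)
index.add(passages)
scores, ids = batched_search(index, queries, topk=5)
print(ids.shape)  # (12, 5)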
Example #6
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):

    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "test-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
    dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
    with dev_query_cache_trivia as emb:
        dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=True),
            "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args, model, GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False, load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training 
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN top-k hit rates on NQ test
        _, dev_I = cpu_index.search(dev_query_embedding, 100)  # I: [number of queries, topk]
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)

        # measure ANN top-k hit rates on TriviaQA test
        _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100)  # I: [number of queries, topk]
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I, dev_query_embedding2id_trivia, passage_embedding2id)

        logger.info("Start searching for query embedding with length %d", len(query_embedding))
        _, I = cpu_index.search(query_embedding, top_k) #I: [number of queries, topk]

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(args, passage_text, train_answers, query_embedding2id, passage_embedding2id, I, train_pos_id)

        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range: 
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(query_id, pos_pid, ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'top20': top_k_hits[19], 'top100': top_k_hits[99], 'top20_trivia': top_k_hits_trivia[19], 
                'top100_trivia': top_k_hits_trivia[99], 'checkpoint': checkpoint_path}, f)
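Each line of the ann_training_data_<n> file written above has the form query_id<TAB>pos_pid<TAB>neg_pid1,neg_pid2,... A minimal sketch of parsing it back; the file name is hypothetical:

# Sketch: parse the tab-separated ANN triplet file written above
with open("ann_training_data_0") as f:  # hypothetical path
    for line in f:
        query_id, pos_pid, neg_pids = line.rstrip("\n").split("\t")
        negatives = [int(p) for p in neg_pids.split(",") if p]
        print(query_id, pos_pid, len(negatives))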
Example #7
def validate(args, closest_docs, dev_scores, query_embedding2id,
             passage_embedding2id):

    logger.info('Matching answers in top docs...')
    scores = dict()

    count = 0
    total = 0
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")
    passage_path = os.path.join(args.passage_path, "hotpot_wiki.tsv")
    idx2title = dict()
    title2text = dict()
    with open(passage_path, "r", encoding="utf-8") as tsvfile:
        reader = csv.reader(
            tsvfile,
            delimiter='\t',
        )
        # file format: doc_id, doc_text, title
        for row in reader:
            if row[0] != 'id':
                idx2title[int(row[0])] = row[2]
                title2text[row[2]] = row[1]
    with open(args.data_dir + '/hotpot_dev_fullwiki_v1.json', 'r') as fin:
        dataset = json.load(fin)

    type_dict = pickle.load(open(args.data_dir + '/dev_type_results.pkl',
                                 'rb'))
    instances = list()

    first_hop_ets = dict()
    for query_idx in range(closest_docs.shape[0]):
        query_id = query_embedding2id[query_idx]
        all_scores = list()
        doc_ids = list()
        all_pred = closest_docs[query_idx]
        scs = dev_scores[query_idx]
        for i in range(len(dev_scores[query_idx])):
            if int(passage_embedding2id[all_pred[i]]) in offset2pid:
                doc_ids.append(offset2pid[int(
                    passage_embedding2id[all_pred[i]])])
                all_scores.append(float(scs[i]))
        data = dataset[query_id]
        qid = data['_id']
        supp_set = set()

        for supp in data['supporting_facts']:
            title = supp[0]
            supp_set.add(normalize(title))

        total += len(supp_set)
        # count supporting titles recovered among the top-10 retrieved docs
        for d_id in doc_ids[:10]:
            title = normalize(idx2title[d_id])
            if title in supp_set:
                count += 1
        first_hop_ets[qid] = {
            'score': all_scores,
            'pred': [normalize(idx2title[idx]) for idx in doc_ids]
        }

        if type_dict[qid] == 'comparison':
            continue

        for et in [normalize(idx2title[idx]) for idx in doc_ids]:
            pre_evidence = title2text[et]
            qq = (data['question'] + ' [SEP] ' + et.replace('_', ' ') +
                  ' [SEP] ' + pre_evidence)
            instances.append({
                'dataset': 'hotpot_dev_sec',
                'question': qq,
                'qid': qid,
                'answers': list(),
                'first_hop_cts': [et]
            })

    with open(args.data_dir + '/dev_sec_hop_data.json', 'w',
              encoding='utf-8') as f:
        json.dump(instances, f, indent=2)
    pickle.dump(first_hop_ets,
                open(args.data_dir + '/dev_first_hop_pred.pkl', 'wb'))
    logger.info("first hop coverage %f", count / total)
Example #8
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data,
                     latest_step_num):
    #passage_text, train_pos_id, train_answers, test_answers, test_pos_id = preloaded_data
    #print(test_pos_id.shape)
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir,
                                               "train-sec-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "query_sec_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-sec-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "dev_sec_query_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=False,
            load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        train_pos_id, test_pos_id = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        #num_q = passage_embedding.shape[0]

        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)

        logger.info('Data indexing completed.')

        logger.info("Start searching for query embedding with length %d",
                    len(query_embedding))

        # search train queries in chunks of 5000 to bound memory use
        num_chunks = (query_embedding.shape[0] + 4999) // 5000
        II = list()
        for i in range(num_chunks):
            _, idx = cpu_index.search(query_embedding[i * 5000:(i + 1) * 5000],
                                      top_k)  # I: [number of queries, topk]
            II.append(idx)
            logger.info("Split done %d", i)

        I = np.concatenate(II, axis=0)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        #logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = dict()
        for query_idx in range(I.shape[0]):
            query_id = query_embedding2id[query_idx]
            doc_ids = [passage_embedding2id[pidx] for pidx in I[query_idx]]

            # positive pids for this query (stored as a comma-separated string)
            pos_id = [int(p_id) for p_id in train_pos_id[query_id].split(',')]

            neg_docs = list()
            for doc_id in doc_ids:
                if doc_id in pos_id:
                    continue
                if doc_id in neg_docs:
                    continue
                neg_docs.append(doc_id)

            query_negative_passage[query_id] = neg_docs

        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid, ','.join(
                        str(neg_pid)
                        for neg_pid in query_negative_passage[query_id])))

        _, dev_I = cpu_index.search(dev_query_embedding,
                                    10)  #I: [number of queries, topk]

        top_k_hits = validate(test_pos_id, dev_I, dev_query_embedding2id,
                              passage_embedding2id)
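The top_k_hits list indexed as top_k_hits[19] / top_k_hits[99] in the examples above is read as a cumulative top-k hit curve. Below is a minimal sketch of computing such a curve; this is an assumption about the validate helper's contract, inferred from how its result is indexed, not its actual implementation:

# Sketch (assumed contract): hits[k-1] = fraction of queries whose positive passage
# appears within the top k retrieved results
import numpy as np

def top_k_hit_curve(ranks_of_first_hit, max_k):
    # ranks_of_first_hit[i]: 0-based rank of the first relevant passage for query i,
    # or -1 if none was retrieved within max_k
    hits = np.zeros(max_k)
    for r in ranks_of_first_hit:
        if 0 <= r < max_k:
            hits[r:] += 1
    return hits / len(ranks_of_first_hit)

print(top_k_hit_curve([0, 3, -1, 1], max_k=5).tolist())  # [0.25, 0.5, 0.5, 0.75, 0.75]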