Esempio n. 1
0
def inference_or_load_embedding(args,logger,model,checkpoint_path,text_data_prefix, emb_prefix, is_query_inference=True,checkonly=False,load_emb=True):
    """Load cached embedding shards for a checkpoint, or run inference if none exist.

    Shard files are looked up in ``args.output_dir`` with the pattern
    ``{emb_prefix}{step}__emb_p__data_obj_*.pb``, where ``step`` is the text
    after the last ``-`` in ``checkpoint_path``. Each embedding shard has a
    sibling id shard whose file name has ``emb_p`` replaced by ``embid_p``.

    Args:
        args: Run configuration; ``output_dir``, ``data_dir`` and
            ``local_rank`` are read here.
        logger: Logger used for progress and error messages.
        model: Model handed to ``StreamInferenceDoc`` when inference runs.
        checkpoint_path: Checkpoint path; its suffix after the last ``-``
            identifies the step.
        text_data_prefix: Name of the text collection under ``args.data_dir``.
        emb_prefix: File-name prefix of the embedding shards.
        is_query_inference: Forwarded to ``GetProcessingFn`` and
            ``StreamInferenceDoc``.
        checkonly: If True and shard files exist, return ``(None, None)``
            without loading anything.
        load_emb: If False, only the id shards are loaded and the returned
            ``embedding`` is an empty list.

    Returns:
        ``(embedding, embedding2id)``. When shard files exist, only the first
        worker loads them; every other rank gets ``(None, None)``. When no
        shard files exist, the result of ``StreamInferenceDoc`` is returned.
    """
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/','')
    emb_file_pattern = os.path.join(args.output_dir,f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_*.pb')
    # Shards end in "_<n>.pb"; order them by that numeric suffix.
    emb_file_lists = sorted(
        glob.glob(emb_file_pattern),
        key=lambda name: int(name.split('_')[-1].replace('.pb','')))
    logger.info(f"pattern {emb_file_pattern}\n file lists: {emb_file_lists}")
    embedding,embedding2id = None,None

    if not emb_file_lists:
        # Nothing cached for this checkpoint: compute the embeddings.
        logger.info(f"***** inference of {text_data_prefix} *****")
        query_collection_path = os.path.join(args.data_dir, text_data_prefix)
        query_cache = EmbeddingCache(query_collection_path)
        with query_cache as emb:
            embedding,embedding2id = StreamInferenceDoc(args, model,
                GetProcessingFn(args, query=is_query_inference),
                emb_prefix + str(checkpoint_step) + "_", emb,
                is_query_inference=is_query_inference)
        return embedding,embedding2id

    first_worker = is_first_worker()
    if first_worker:
        logger.info(f"***** found existing embedding files {emb_file_pattern}, loading... *****")

    if checkonly:
        # Every rank leaves before the barrier below. Previously only the first
        # worker returned early, leaving the other ranks waiting on a barrier
        # the first worker never entered.
        if first_worker:
            logger.info("check embedding files only, not loading")
        return embedding,embedding2id

    if first_worker:
        embedding = []
        embedding2id = []
        for emb_file in emb_file_lists:
            # NOTE: pickle.load executes arbitrary code from the file; these
            # shards must come from a trusted output directory.
            if load_emb:
                with open(emb_file,'rb') as handle:
                    embedding.append(pickle.load(handle))
            # Rename only the file name so a directory component that happens
            # to contain "emb_p" is left alone.
            shard_dir, shard_name = os.path.split(emb_file)
            embid_file = os.path.join(shard_dir, shard_name.replace('emb_p','embid_p'))
            with open(embid_file,'rb') as handle:
                embedding2id.append(pickle.load(handle))
        if (load_emb and not embedding) or (not embedding2id):
            logger.error("No data found for checkpoint: %s", emb_file_pattern)
        if load_emb:
            embedding = np.concatenate(embedding, axis=0)
        embedding2id = np.concatenate(embedding2id, axis=0)

    if args.local_rank != -1:
        dist.barrier() # if multi-processing
    return embedding,embedding2id
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        latest_step_num,
        training_query_positive_id=None,
        dev_query_positive_id=None,
        ):
    """Embed queries and passages with a checkpoint and write new ANN training data.

    Runs ``StreamInferenceDoc`` over the ``dev-query``, ``passages`` and
    ``train-query`` collections under ``args.data_dir``. On the first worker
    it then builds a flat inner-product FAISS index over the passages,
    evaluates the dev queries, and writes two files into ``args.output_dir``:

    * ``ann_training_data_<output_num>``: one line per training query in the
      form ``query_id<TAB>pos_pid<TAB>neg_pid,neg_pid,...``.
    * ``ann_ndcg_<output_num>``: JSON with the dev metric and checkpoint path.

    Args:
        args: Run configuration; ``data_dir``, ``output_dir``, ``inference``,
            ``topk_training`` and ``ann_chunk_factor`` are read here.
        output_num: Index of this ANN generation round; used in the output
            file names and to select the training-query chunk.
        checkpoint_path: Checkpoint passed to ``load_model``.
        latest_step_num: Step number embedded in the inference name prefixes.
        training_query_positive_id: Mapping from training query id to its
            positive passage id.
        dev_query_positive_id: Dev relevance data forwarded to ``EvalDevQuery``.

    Returns:
        ``(dev_ndcg, num_queries_dev)`` on the first worker. ``None`` when
        ``args.inference`` is truthy (after dev-query inference only) and on
        every other worker.
    """
    config, tokenizer, model = load_model(args, checkpoint_path)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=True), "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    if args.inference:
        return

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False)
    logger.info("***** Done passage inference *****")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # Evaluate the dev queries against the index.
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(
            args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        # Clamp the chunk factor BEFORE using it as a modulus; otherwise a
        # factor of 0 raises ZeroDivisionError and a negative factor yields a
        # chunk index that does not match the clamped value.
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        # The last chunk also absorbs the remainder of the integer division.
        if effective_idx == chunk_factor - 1:
            q_end_idx = num_queries
        else:
            q_end_idx = q_start_idx + queries_per_chunk
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())
        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding2id,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            # Write the queries in random order.
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # Skip queries that have no known positive passage.
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, pos_pid, ','.join(
                            str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
Esempio n. 3
0
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        training_query_positive_id,
        dev_query_positive_id,
        latest_step_num):
    """Build ANN training data with BM25-scored positives and negatives.

    Embeddings come either from ``load_init_embeddings`` (first worker only,
    when ``args.end_output_num == 0``) or from running ``StreamInferenceDoc``
    over the test-query, ``passages`` and ``train-query`` collections under
    ``args.data_dir``. On the first worker the function then builds a flat
    inner-product FAISS index and writes into ``args.output_dir``:

    * ``dist_<output_num>``: pickle of the ``getSamplingDist`` results.
    * ``ann_training_data_<output_num>``: one line per training query in the
      form ``query_id<TAB>pos_pid:score<TAB>neg_pid:score,neg_pid:score,...``
      with scores from ``get_BM25_score`` rounded to 3 decimals.
    * ``ann_ndcg_<output_num>``: JSON with the dev metric and checkpoint path.

    Args:
        args: Run configuration; ``dataset``, ``data_dir``, ``output_dir``,
            ``inference``, ``end_output_num``, ``local_rank``,
            ``topk_training`` and ``ann_chunk_factor`` are read here.
        output_num: Index of this ANN generation round.
        checkpoint_path: Checkpoint passed to ``load_model``.
        training_query_positive_id: Mapping from training query id to its
            positive passage id.
        dev_query_positive_id: Relevance data forwarded to ``EvalDevQuery``
            and ``getSamplingDist``.
        latest_step_num: Step number embedded in the inference name prefixes.

    Returns:
        ``(dev_ndcg, num_queries_dev)`` on the first worker. ``None`` when
        inference ran and ``args.inference`` is truthy, and on every other
        worker.

    Raises:
        Exception: If inference is needed and ``args.dataset`` is not one of
            ``dl_test_2019``, ``dl_test_2019_1``, ``dl_test_2019_2``.
    """
    config, tokenizer, model = load_model(args, checkpoint_path)

    dataFound = False
    # Bound up-front so the triplet-writing step can tell whether the
    # inference branch already created the cache.
    train_query_cache = None

    if args.end_output_num == 0 and is_first_worker():
        dataFound, query_embedding, query_embedding2id, dev_query_embedding, dev_query_embedding2id, passage_embedding, passage_embedding2id = load_init_embeddings(args)

    if args.local_rank != -1:
        dist.barrier()

    # NOTE(review): dataFound is only ever set on the first worker, so in a
    # multi-process run the other ranks take the inference branch while the
    # first worker may skip it. If StreamInferenceDoc synchronises across
    # ranks this can hang -- confirm against its implementation.
    if not dataFound:
        logger.info("***** inference of dev query *****")
        dev_query_dirs = {
            "dl_test_2019": "test-query",
            "dl_test_2019_1": "test-query-1",
            "dl_test_2019_2": "test-query-2",
        }
        if args.dataset not in dev_query_dirs:
            raise Exception('Dataset should be one of {dl_test_2019, dl_test_2019_1, dl_test_2019_2}!!')
        dev_query_collection_path = os.path.join(args.data_dir, dev_query_dirs[args.dataset])
        name_ = "test_query_"
        dev_query_cache = EmbeddingCache(dev_query_collection_path)
        with dev_query_cache as emb:
            dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=True), name_ + str(latest_step_num) + "_", emb, is_query_inference=True)

        logger.info("***** inference of passages *****")
        passage_collection_path = os.path.join(args.data_dir, "passages")
        passage_cache = EmbeddingCache(passage_collection_path)
        with passage_cache as emb:
            passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False)
        logger.info("***** Done passage inference *****")

        if args.inference:
            return

        logger.info("***** inference of train query *****")
        train_query_collection_path = os.path.join(args.data_dir, "train-query")
        train_query_cache = EmbeddingCache(train_query_collection_path)
        with train_query_cache as emb:
            query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    else:
        logger.info("***** Found pre-existing embeddings. So not running inference again. *****")

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # Evaluate the dev queries against the index.
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        # Clamp the chunk factor BEFORE using it as a modulus; otherwise a
        # factor of 0 raises ZeroDivisionError and a negative factor yields a
        # chunk index that does not match the clamped value.
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        # The last chunk also absorbs the remainder of the integer division.
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) else (q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())

        # Persist the sampling distribution derived from the dev queries.
        _, dev_I_dist = cpu_index.search(dev_query_embedding, top_k)
        distrib, samplingDist = getSamplingDist(args, dev_I_dist, dev_query_embedding2id, dev_query_positive_id, passage_embedding2id)
        sampling_dist_data = {'distrib': distrib, 'samplingDist': samplingDist}
        dist_output_path = os.path.join(args.output_dir, "dist_" + str(output_num))
        with open(dist_output_path, 'wb') as f:
            pickle.dump(sampling_dist_data, f)

        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding,
            query_embedding2id,
            passage_embedding,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id,
            samplingDist,
            output_num)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        # When embeddings were loaded from disk the inference branch never
        # created the cache; build it here so BM25 scoring still works.
        if train_query_cache is None:
            train_query_cache = EmbeddingCache(
                os.path.join(args.data_dir, "train-query"))

        train_query_cache.open()
        try:
            with open(train_data_output_path, 'w') as f:
                # Write the queries in random order.
                query_range = list(range(I.shape[0]))
                random.shuffle(query_range)
                for query_idx in query_range:
                    query_id = query_embedding2id[query_idx]
                    # Skip queries that have no known positive passage.
                    if query_id not in effective_q_id or query_id not in training_query_positive_id:
                        continue
                    pos_pid = training_query_positive_id[query_id]
                    pos_score = get_BM25_score(query_id, pos_pid, train_query_cache, tokenizer)
                    neg_pids = query_negative_passage[query_id]
                    neg_scores = {}
                    for neg_pid in neg_pids:
                        neg_scores[neg_pid] = get_BM25_score(query_id, neg_pid, train_query_cache, tokenizer)
                    f.write(
                        "{}\t{}\t{}\n".format(
                            query_id, str(pos_pid)+":"+str(round(pos_score,3)), ','.join(
                                str(neg_pid)+":"+str(round(neg_scores[neg_pid],3)) for neg_pid in neg_pids)))
        finally:
            # Release the cache even if scoring or writing fails part-way.
            train_query_cache.close()

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev