Example #1
def main():
    args = get_arguments()
    set_env(args)
    tokenizer, model = load_model(args)

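    # EmbeddingCache wraps the preprocessed "train-query" and "passages" binaries; both stay open for the whole training run.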
    query_collection_path = os.path.join(args.data_dir, "train-query")
    query_cache = EmbeddingCache(query_collection_path)
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)

    with query_cache, passage_cache:
        global_step = train(args, model, tokenizer, query_cache, passage_cache)
        logger.info(" global_step = %s", global_step)

    save_checkpoint(args, model, tokenizer)
Example #2
def main():
    args = get_arguments()
    set_env(args)
    model = load_model(args)

    query_collection_path = os.path.join(args.data_dir, "train-query")
    query_cache = EmbeddingCache(query_collection_path)
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)

    with query_cache, passage_cache:
        global_step = train(args, model, query_cache, passage_cache)
        logger.info(" global_step = %s", global_step)

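    # In distributed runs (local_rank != -1), wait for every rank to finish before exiting.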
    if args.local_rank != -1:
        dist.barrier()
Example #3
def preprocess(args):

    pid2offset = {}
    in_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if False:  # existence check disabled, so preprocessing always re-runs
        print("preprocessed data already exist, exit preprocessing")
        return
    else:
        out_line_count = 0

        print('start passage file split processing')
        print(args.model_type)
        #multi_file_process(args, 32, in_passage_path, out_passage_path, PassagePreprocessingFn)

        print('start merging splits')
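        # Each fixed-size record starts with an 8-byte big-endian passage id; the id is stripped, the
        # remaining bytes (presumably a 4-byte length plus max_seq_length int32 token ids, matching the
        # meta written below) are copied out, and the passage's row offset is recorded in pid2offset.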
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(
                    numbered_byte_file_generator(
                        in_passage_path, 53, 8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1
                print(out_line_count)

        print("Total lines written: " + str(out_line_count))
        meta = {
            'type': 'int32',
            'total_number': out_line_count,
            'embedding_size': args.max_seq_length
        }
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        write_mapping(args, pid2offset, "pid2offset")
    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[pid2offset[1]])

    write_qas_query(args, pid2offset, "hotpot_train_first_hop.json",
                    "train-query", "train-ann")
    write_qas_query(args, pid2offset, "hotpot_dev_first_hop.json", "dev-query",
                    "dev-ann")
    write_qas_query(args, pid2offset, "hotpot_train_sec_hop.json",
                    "train-sec-query", "train-sec-ann")
    write_qas_query(args, pid2offset, "hotpot_dev_sec_hop.json",
                    "dev-sec-query", "dev-sec-ann")
Example #4
def write_query_norel(args, prefix, query_file, out_query_file):

    query_collection_path = os.path.join(
        args.data_dir,
        query_file,
    )

    out_query_path = os.path.join(
        args.out_data_dir,
        out_query_file,
    )

    qid2offset = {}

    print('start query file split processing')
    multi_file_process(args, args.n_split_process, query_collection_path,
                       out_query_path, QueryPreprocessingFn)

    print('start merging splits')

    idx = 0
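    # Merge the per-process split files: drop the 8-byte query id prefix from each record and record
    # the row offset it lands at, so qid2offset can map raw query ids to cache positions.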
    with open(out_query_path, 'wb') as f:
        for record in numbered_byte_file_generator(
                out_query_path, args.n_split_process,
                8 + 4 + args.max_query_length * 4):
            q_id = int.from_bytes(record[:8], 'big')
            f.write(record[8:])
            qid2offset[q_id] = idx
            idx += 1
            if idx < 3:
                print(str(idx) + " " + str(q_id))
    for i in range(args.n_split_process):
        os.remove('{}_split{}'.format(out_query_path,
                                      i))  # delete intermediate files

    qid2offset_path = os.path.join(
        args.out_data_dir,
        prefix + "_qid2offset.pickle",
    )
    with open(qid2offset_path, 'wb') as handle:
        pickle.dump(qid2offset, handle, protocol=4)
    print("done saving qid2offset")

    print("Total lines written: " + str(idx))
    meta = {
        'type': 'int32',
        'total_number': idx,
        'embedding_size': args.max_query_length
    }
    with open(out_query_path + "_meta", 'w') as f:
        json.dump(meta, f)

    embedding_cache = EmbeddingCache(out_query_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])
Example #5
def evaluate_dev(args, model, passage_cache, source=""):

    dev_query_collection_path = os.path.join(args.data_dir,
                                             "dev-query{}".format(source))
    dev_query_cache = EmbeddingCache(dev_query_collection_path)

    logger.info('NLL validation ...')

    model.eval()

    log_result_step = 100
    batches = 0
    total_loss = 0.0
    total_correct_predictions = 0

    with dev_query_cache:
        dev_data_path = os.path.join(args.data_dir,
                                     "dev-data{}".format(source))
        with open(dev_data_path, 'r') as f:
            dev_data = f.readlines()
        dev_dataset = StreamingDataset(
            dev_data,
            GetTrainingDataProcessingFn(args,
                                        dev_query_cache,
                                        passage_cache,
                                        shuffle=False))
        dev_dataloader = DataLoader(dev_dataset,
                                    batch_size=args.train_batch_size * 2)

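        # Evaluation loop: accumulate NLL loss and correct-prediction counts over the dev batches.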
        for i, batch in enumerate(dev_dataloader):
            loss, correct_cnt = do_biencoder_fwd_pass(args, model, batch)
            loss.backward()  # get CUDA oom without this
            model.zero_grad()
            total_loss += loss.item()
            total_correct_predictions += correct_cnt
            batches += 1
            if (i + 1) % log_result_step == 0:
                logger.info('Eval step: %d , loss=%f ', i, loss.item())

    total_loss = total_loss / batches
    total_samples = batches * args.train_batch_size * torch.distributed.get_world_size()
    correct_ratio = float(total_correct_predictions / total_samples)
    logger.info(
        'NLL Validation: loss = %f. correct prediction ratio  %d/%d ~  %f',
        total_loss, total_correct_predictions, total_samples, correct_ratio)

    model.train()
    return total_loss, correct_ratio
Example #6
def inference_or_load_embedding(args,logger,model,checkpoint_path,text_data_prefix, emb_prefix, is_query_inference=True,checkonly=False,load_emb=True):
    # logging.info(f"checkpoint_path {checkpoint_path}")
    checkpoint_step = checkpoint_path.split('-')[-1].replace('/','')
    emb_file_pattern = os.path.join(args.output_dir,f'{emb_prefix}{checkpoint_step}__emb_p__data_obj_*.pb')
    emb_file_lists = glob.glob(emb_file_pattern)
    emb_file_lists = sorted(emb_file_lists, key=lambda name: int(name.split('_')[-1].replace('.pb',''))) # sort with split num
    logger.info(f"pattern {emb_file_pattern}\n file lists: {emb_file_lists}")
    embedding,embedding2id = [None,None]
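    # Two paths: reuse pickled embedding/id files from an earlier inference run if any match the
    # checkpoint-step pattern, otherwise run inference over the text data below.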
    if len(emb_file_lists) > 0:
        if is_first_worker():
            logger.info(f"***** found existing embedding files {emb_file_pattern}, loading... *****")
            if checkonly:
                logger.info("check embedding files only, not loading")
                return embedding,embedding2id
            embedding = []
            embedding2id = []
            for emb_file in emb_file_lists:
                if load_emb:
                    with open(emb_file,'rb') as handle:
                        embedding.append(pickle.load(handle))
                embid_file = emb_file.replace('emb_p','embid_p')
                with open(embid_file,'rb') as handle:
                    embedding2id.append(pickle.load(handle))
            if (load_emb and not embedding) or (not embedding2id):
                logger.error("No data found for checkpoint: %s", emb_file_pattern)
            if load_emb:
                embedding = np.concatenate(embedding, axis=0)
            embedding2id = np.concatenate(embedding2id, axis=0)
        # return embedding,embedding2id
    # else:
    #    if args.local_rank != -1:
    #        dist.barrier() # if multi-processing
    else:
        logger.info(f"***** inference of {text_data_prefix} *****")
        query_collection_path = os.path.join(args.data_dir, text_data_prefix)
        query_cache = EmbeddingCache(query_collection_path)
        with query_cache as emb:
            embedding,embedding2id = StreamInferenceDoc(args, model, 
                GetProcessingFn(args, query=is_query_inference),
                emb_prefix + str(checkpoint_step) + "_", emb,
                is_query_inference=is_query_inference)
    return embedding,embedding2id
Example #7
def write_query_rel(args,
                    pid2offset,
                    query_file,
                    out_query_file,
                    out_ann_file,
                    out_train_file,
                    passage_id_name="passage_id"):

    print("Writing query files " + str(out_query_file) + " and " +
          str(out_ann_file))

    query_path = os.path.join(
        args.question_dir,
        query_file,
    )

    with open(query_path, 'r', encoding="utf-8") as f:
        data = json.load(f)
        print('Aggregated data size: {}'.format(len(data)))

    data = [r for r in data if len(r['positive_ctxs']) > 0]
    print('Total cleaned data size: {}'.format(len(data)))
    data = [r for r in data if len(r['hard_negative_ctxs']) > 0]
    print('Total cleaned data size: {}'.format(len(data)))

    out_query_path = os.path.join(
        args.out_data_dir,
        out_query_file,
    )

    out_ann_file = os.path.join(
        args.out_data_dir,
        out_ann_file,
    )

    out_training_path = os.path.join(
        args.out_data_dir,
        out_train_file,
    )

    qid = 0

    configObj = MSMarcoConfigDict[args.model_type]
    tokenizer = configObj.tokenizer_class.from_pretrained(
        args.model_name_or_path,
        do_lower_case=True,
        cache_dir=None,
    )

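    # Write three aligned outputs per query: the tokenized query as binary, an "ann" line of
    # (qid, first positive pid, answers), and a training line of (qid, first positive pid, negative pids).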
    with open(out_query_path, "wb") as out_query, \
            open(out_ann_file, "w", encoding='utf-8') as out_ann, \
            open(out_training_path, "w", encoding='utf-8') as out_training:
        for sample in data:
            positive_ctxs = sample['positive_ctxs']
            neg_ctxs = sample['hard_negative_ctxs']
            question = normalize_question(sample['question'])
            first_pos_pid = pid2offset[int(positive_ctxs[0][passage_id_name])]
            neg_pids = [
                str(pid2offset[int(neg_ctx[passage_id_name])])
                for neg_ctx in neg_ctxs
            ]
            out_ann.write("{}\t{}\t{}\n".format(qid, first_pos_pid,
                                                sample["answers"]))
            out_training.write("{}\t{}\t{}\n".format(qid, first_pos_pid,
                                                     ','.join(neg_pids)))
            out_query.write(
                QueryPreprocessingFn(args, qid, question, tokenizer))
            qid += 1

    print("Total lines written: " + str(qid))
    meta = {
        'type': 'int32',
        'total_number': qid,
        'embedding_size': args.max_seq_length
    }
    with open(out_query_path + "_meta", 'w') as f:
        json.dump(meta, f)

    embedding_cache = EmbeddingCache(out_query_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])
Example #8
def preprocess(args):

    pid2offset = {}
    in_passage_path = os.path.join(
        args.wiki_dir,
        "psgs_w100.tsv",
    )
    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return
    else:
        out_line_count = 0

        print('start passage file split processing')
        multi_file_process(args, 32, in_passage_path, out_passage_path,
                           PassagePreprocessingFn)

        print('start merging splits')
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(
                    numbered_byte_file_generator(
                        out_passage_path, 32,
                        8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1

        print("Total lines written: " + str(out_line_count))
        meta = {
            'type': 'int32',
            'total_number': out_line_count,
            'embedding_size': args.max_seq_length
        }
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        write_mapping(args, pid2offset, "pid2offset")

    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[pid2offset[1]])

    if args.data_type == 0:
        write_query_rel(args, pid2offset, "nq-train.json", "train-query",
                        "train-ann", "train-data")
    elif args.data_type == 1:
        write_query_rel(args, pid2offset, "trivia-train.json", "train-query",
                        "train-ann", "train-data", "psg_id")
    else:
        # use both training dataset and merge them
        write_query_rel(args, pid2offset, "nq-train.json", "train-query-nq",
                        "train-ann-nq", "train-data-nq")
        write_query_rel(args, pid2offset, "trivia-train.json",
                        "train-query-trivia", "train-ann-trivia",
                        "train-data-trivia", "psg_id")

        with open(args.out_data_dir + "train-query-nq", "rb") as nq_query, \
                open(args.out_data_dir + "train-query-trivia", "rb") as trivia_query, \
                open(args.out_data_dir + "train-query", "wb") as out_query:
            out_query.write(nq_query.read())
            out_query.write(trivia_query.read())

        with open(args.out_data_dir + "train-query-nq_meta", "r", encoding='utf-8') as nq_query, \
                open(args.out_data_dir + "train-query-trivia_meta", "r", encoding='utf-8') as trivia_query, \
                open(args.out_data_dir + "train-query_meta", "w", encoding='utf-8') as out_query:
            a = json.load(nq_query)
            b = json.load(trivia_query)
            meta = {
                'type': 'int32',
                'total_number': a['total_number'] + b['total_number'],
                'embedding_size': args.max_seq_length
            }
            json.dump(meta, out_query)

        embedding_cache = EmbeddingCache(args.out_data_dir + "train-query")
        print("First line after merge")
        with embedding_cache as emb:
            print(emb[58812])

        with open(args.out_data_dir + "train-ann-nq", "r", encoding='utf-8') as nq_ann, \
                open(args.out_data_dir + "train-ann-trivia", "r", encoding='utf-8') as trivia_ann, \
                open(args.out_data_dir + "train-ann", "w", encoding='utf-8') as out_ann:
            out_ann.writelines(nq_ann.readlines())
            out_ann.writelines(trivia_ann.readlines())

    write_query_rel(args, pid2offset, "nq-dev.json", "dev-query", "dev-ann",
                    "dev-data")
    write_query_rel(args, pid2offset, "trivia-dev.json", "dev-query-trivia",
                    "dev-ann-trivia", "dev-data-trivia", "psg_id")
    write_qas_query(args, "nq-test.csv", "test-query")
    write_qas_query(args, "trivia-test.csv", "trivia-test-query")
Example #9
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        latest_step_num,
        training_query_positive_id=None,
        dev_query_positive_id=None,
        ):
    config, tokenizer, model = load_model(args, checkpoint_path)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=True), "dev_query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    if args.inference:
        return

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False)
    logger.info("***** Done passage inference *****")

    

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
            args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(
            args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) else (q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())
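        # Feed each training query's top_k ANN hits to GenerateNegativePassaageID to build its hard-negative pool.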
        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding2id,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, pos_pid, ','.join(
                            str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
Example #10
def write_query_rel(args, pid2offset, query_file, positive_id_file,
                    out_query_file, out_id_file):

    # args,
    # pid2offset,
    # "msmarco-test2019-queries.tsv",
    # "2019qrels-docs.txt",
    # "dev-query",
    # "dev-qrel.tsv"
    print("Writing query files " + str(out_query_file) + " and " +
          str(out_id_file))
    query_positive_id = set()

    query_positive_id_path = os.path.join(
        args.data_dir,
        positive_id_file,
    )

    print("Loading query_2_pos_docid")
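    # qrels rows look like "<topicid> <ignored> <docid> <rel>"; only the topic ids are collected here
    # so that queries without relevance labels can be skipped during the merge below.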
    with gzip.open(query_positive_id_path, 'rt',
                   encoding='utf8') if positive_id_file[-2:] == "gz" else open(
                       query_positive_id_path, 'r', encoding='utf8') as f:
        if args.data_type == 0:
            tsvreader = csv.reader(f, delimiter=" ")
        else:
            tsvreader = csv.reader(f, delimiter="\t")
        for [topicid, _, docid, rel] in tsvreader:
            query_positive_id.add(int(topicid))

    query_collection_path = os.path.join(
        args.data_dir,
        query_file,
    )

    out_query_path = os.path.join(
        args.out_data_dir,
        out_query_file,
    )

    qid2offset = {}

    print('start query file split processing')
    multi_file_process(args, 32, query_collection_path, out_query_path,
                       QueryPreprocessingFn)

    print('start merging splits')

    idx = 0
    with open(out_query_path, 'wb') as f:
        for record in numbered_byte_file_generator(
                out_query_path, 32, 8 + 4 + args.max_query_length * 4):
            q_id = int.from_bytes(record[:8], 'big')
            if q_id not in query_positive_id:
                # exclude the query as it is not in label set
                continue
            f.write(record[8:])
            qid2offset[q_id] = idx
            idx += 1
            if idx < 3:
                print(str(idx) + " " + str(q_id))

    qid2offset_path = os.path.join(
        args.out_data_dir,
        "qid2offset.pickle",
    )
    with open(qid2offset_path, 'wb') as handle:
        pickle.dump(qid2offset, handle, protocol=4)
    print("done saving qid2offset")

    print("Total lines written: " + str(idx))
    meta = {
        'type': 'int32',
        'total_number': idx,
        'embedding_size': args.max_query_length
    }
    with open(out_query_path + "_meta", 'w') as f:
        json.dump(meta, f)

    embedding_cache = EmbeddingCache(out_query_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])

    out_id_path = os.path.join(
        args.out_data_dir,
        out_id_file,
    )

    print("Writing qrels")
    with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', encoding='utf8') as f, \
            open(out_id_path, "w", encoding='utf-8') as out_id:

        if args.data_type == 0:
            tsvreader = csv.reader(f, delimiter=" ")
        else:
            tsvreader = csv.reader(f, delimiter="\t")
        out_line_count = 0
        for [topicid, _, docid, rel] in tsvreader:
            topicid = int(topicid)
            if args.data_type == 0:
                docid = int(docid[1:])
            else:
                docid = int(docid)
            out_id.write(
                str(qid2offset[topicid]) + "\t" + str(pid2offset[docid]) +
                "\t" + rel + "\n")
            out_line_count += 1
        print("Total lines written: " + str(out_line_count))
Example #11
def preprocess(args):

    pid2offset = {}
    if args.data_type == 0:
        in_passage_path = os.path.join(
            args.data_dir,
            "msmarco-docs.tsv",
        )
    else:
        in_passage_path = os.path.join(
            args.data_dir,
            "collection.tsv",
        )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return

    out_line_count = 0

    print('start passage file split processing')
    multi_file_process(args, 32, in_passage_path, out_passage_path,
                       PassagePreprocessingFn)
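    # Shard the input TSV across 32 worker processes; each writes fixed-size records to its own split file, merged below.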

    print('start merging splits')
    with open(out_passage_path, 'wb') as f:
        for idx, record in enumerate(
                numbered_byte_file_generator(out_passage_path, 32,
                                             8 + 4 + args.max_seq_length * 4)):
            p_id = int.from_bytes(record[:8], 'big')
            f.write(record[8:])
            pid2offset[p_id] = idx
            if idx < 3:
                print(str(idx) + " " + str(p_id))
            out_line_count += 1

    print("Total lines written: " + str(out_line_count))
    meta = {
        'type': 'int32',
        'total_number': out_line_count,
        'embedding_size': args.max_seq_length
    }
    with open(out_passage_path + "_meta", 'w') as f:
        json.dump(meta, f)
    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])

    pid2offset_path = os.path.join(
        args.out_data_dir,
        "pid2offset.pickle",
    )
    with open(pid2offset_path, 'wb') as handle:
        pickle.dump(pid2offset, handle, protocol=4)
    print("done saving pid2offset")

    if args.data_type == 0:
        write_query_rel(args, pid2offset, "msmarco-doctrain-queries.tsv",
                        "msmarco-doctrain-qrels.tsv", "train-query",
                        "train-qrel.tsv")
        write_query_rel(args, pid2offset, "msmarco-test2019-queries.tsv",
                        "2019qrels-docs.txt", "dev-query", "dev-qrel.tsv")
    else:
        write_query_rel(args, pid2offset, "queries.train.tsv",
                        "qrels.train.tsv", "train-query", "train-qrel.tsv")
        write_query_rel(args, pid2offset, "queries.dev.small.tsv",
                        "qrels.dev.small.tsv", "dev-query", "dev-qrel.tsv")
Example #12
def generate_new_ann(args):
    #print(test_pos_id.shape)
    #model = None
    model = load_model(args, args.model)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    latest_step_num = args.latest_num
    args.world_size = args.world_size

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-eval-sec")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "dev-sec_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True,
            load_cache=args.load_cache)
    #exit()
    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=False,
            load_cache=args.load_cache)

    dim = passage_embedding.shape[1]
    #print(dev_query_embedding.shape)
    #print('passage embedding shape: ' + str(passage_embedding.shape))
    print('dev embedding shape: ' + str(dev_query_embedding.shape))

    faiss.omp_set_num_threads(16)
    cpu_index = faiss.IndexFlatIP(dim)
    cpu_index.add(passage_embedding)
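    # Exact (brute-force) inner-product index over all passage embeddings; dev queries are searched in 5000-query chunks below.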

    logger.info('Data indexing completed.')
    nums = int(dev_query_embedding.shape[0] / 5000) + 1
    II = list()
    sscores = list()
    for i in range(nums):
        score, idx = cpu_index.search(
            dev_query_embedding[i * 5000:(i + 1) * 5000],
            args.topk)  # I: [number of queries, topk]
        II.append(idx)
        sscores.append(score)
        logger.info("Split done %d", i)

    dev_I = II[0]
    scores = sscores[0]
    for i in range(1, nums):
        dev_I = np.concatenate((dev_I, II[i]), axis=0)
        scores = np.concatenate((scores, sscores[i]), axis=0)

    validate(args, dev_I, scores, dev_query_embedding2id, passage_embedding2id)
Example #13
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data, latest_step_num):

    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir, "train-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "query_" + str(latest_step_num)+"_", emb, is_query_inference = True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "test-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "dev_query_"+ str(latest_step_num)+"_", emb, is_query_inference = True)

    dev_query_collection_path_trivia = os.path.join(args.data_dir, "trivia-test-query")
    dev_query_cache_trivia = EmbeddingCache(dev_query_collection_path_trivia)
    with dev_query_cache_trivia as emb:
        dev_query_embedding_trivia, dev_query_embedding2id_trivia = StreamInferenceDoc(args, model, GetProcessingFn(args, query=True), "dev_query_"+ str(latest_step_num)+"_", emb, is_query_inference = True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(args, query=False), "passage_"+ str(latest_step_num)+"_", emb, is_query_inference = False, load_cache = False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        passage_text, train_pos_id, train_answers, test_answers, test_answers_trivia = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training 
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr 
        _, dev_I = cpu_index.search(dev_query_embedding, 100) #I: [number of queries, topk]
        top_k_hits = validate(passage_text, test_answers, dev_I, dev_query_embedding2id, passage_embedding2id)

        # measure ANN mrr (trivia)
        _, dev_I = cpu_index.search(dev_query_embedding_trivia, 100) #I: [number of queries, topk]
        top_k_hits_trivia = validate(passage_text, test_answers_trivia, dev_I, dev_query_embedding2id_trivia, passage_embedding2id)

        logger.info("Start searching for query embedding with length %d", len(query_embedding))
        _, I = cpu_index.search(query_embedding, top_k) #I: [number of queries, topk]

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
        query_negative_passage = GenerateNegativePassaageID(args, passage_text, train_answers, query_embedding2id, passage_embedding2id, I, train_pos_id)

        logger.info("Done generating negative passages, output length %d", len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range: 
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(query_id, pos_pid, ','.join(str(neg_pid) for neg_pid in query_negative_passage[query_id])))

        ndcg_output_path = os.path.join(args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'top20': top_k_hits[19], 'top100': top_k_hits[99], 'top20_trivia': top_k_hits_trivia[19], 
                'top100_trivia': top_k_hits_trivia[99], 'checkpoint': checkpoint_path}, f)
Example #14
def generate_new_ann(
        args,
        output_num,
        checkpoint_path,
        training_query_positive_id,
        dev_query_positive_id,
        latest_step_num):
    config, tokenizer, model = load_model(args, checkpoint_path)

    dataFound = False

    if args.end_output_num == 0 and is_first_worker():
        dataFound, query_embedding, query_embedding2id, dev_query_embedding, dev_query_embedding2id, passage_embedding, passage_embedding2id = load_init_embeddings(args)

    if args.local_rank != -1:
        dist.barrier()

    if not dataFound:
        logger.info("***** inference of dev query *****")
        name_ = None
        if args.dataset == "dl_test_2019":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_1":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-1")
            name_ = "test_query_"
        elif args.dataset == "dl_test_2019_2":
            dev_query_collection_path = os.path.join(args.data_dir, "test-query-2")
            name_ = "test_query_"
        else:
            raise Exception('Dataset should be one of {dl_test_2019, dl_test_2019_1, dl_test_2019_2}!!')
        dev_query_cache = EmbeddingCache(dev_query_collection_path)
        with dev_query_cache as emb:
            dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=True), name_ + str(latest_step_num) + "_", emb, is_query_inference=True)

        logger.info("***** inference of passages *****")
        passage_collection_path = os.path.join(args.data_dir, "passages")
        passage_cache = EmbeddingCache(passage_collection_path)
        with passage_cache as emb:
            passage_embedding, passage_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=False), "passage_" + str(latest_step_num) + "_", emb, is_query_inference=False)
        logger.info("***** Done passage inference *****")

        if args.inference:
            return

        logger.info("***** inference of train query *****")
        train_query_collection_path = os.path.join(args.data_dir, "train-query")
        train_query_cache = EmbeddingCache(train_query_collection_path)
        with train_query_cache as emb:
            query_embedding, query_embedding2id = StreamInferenceDoc(args, model, GetProcessingFn(
                args, query=True), "query_" + str(latest_step_num) + "_", emb, is_query_inference=True)
    else:
        logger.info("***** Found pre-existing embeddings. So not running inference again. *****")

    if is_first_worker():
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)
        logger.info("***** Done ANN Index *****")

        # measure ANN mrr
        # I: [number of queries, topk]
        _, dev_I = cpu_index.search(dev_query_embedding, 100)
        dev_ndcg, num_queries_dev = EvalDevQuery(args, dev_query_embedding2id, passage_embedding2id, dev_query_positive_id, dev_I)

        # Construct new training set ==================================
        chunk_factor = args.ann_chunk_factor
        if chunk_factor <= 0:
            chunk_factor = 1
        effective_idx = output_num % chunk_factor

        num_queries = len(query_embedding)
        queries_per_chunk = num_queries // chunk_factor
        q_start_idx = queries_per_chunk * effective_idx
        q_end_idx = num_queries if (effective_idx == (chunk_factor - 1)) else (q_start_idx + queries_per_chunk)
        query_embedding = query_embedding[q_start_idx:q_end_idx]
        query_embedding2id = query_embedding2id[q_start_idx:q_end_idx]

        logger.info(
            "Chunked {} query from {}".format(
                len(query_embedding),
                num_queries))
        # I: [number of queries, topk]
        _, I = cpu_index.search(query_embedding, top_k)

        effective_q_id = set(query_embedding2id.flatten())

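        # Derive a sampling distribution from the dev-query ANN results (getSamplingDist), save it,
        # and pass it into GenerateNegativePassaageID below.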
        _, dev_I_dist = cpu_index.search(dev_query_embedding, top_k)
        distrib, samplingDist = getSamplingDist(args, dev_I_dist, dev_query_embedding2id, dev_query_positive_id, passage_embedding2id)
        sampling_dist_data = {'distrib': distrib, 'samplingDist': samplingDist}
        dist_output_path = os.path.join(args.output_dir, "dist_" + str(output_num))
        with open(dist_output_path, 'wb') as f:
            pickle.dump(sampling_dist_data, f)

        query_negative_passage = GenerateNegativePassaageID(
            args,
            query_embedding,
            query_embedding2id,
            passage_embedding,
            passage_embedding2id,
            training_query_positive_id,
            I,
            effective_q_id,
            samplingDist,
            output_num)

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        train_query_cache.open()
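        # Triplets here carry scores: each pid is written as "pid:BM25_score" (get_BM25_score, rounded to 3 decimals).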
        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                if query_id not in effective_q_id or query_id not in training_query_positive_id:
                    continue
                pos_pid = training_query_positive_id[query_id]
                pos_score = get_BM25_score(query_id, pos_pid, train_query_cache, tokenizer)
                neg_scores = {}
                for neg_pid in query_negative_passage[query_id]:
                    neg_scores[neg_pid] = get_BM25_score(query_id, neg_pid, train_query_cache, tokenizer)
                f.write(
                    "{}\t{}\t{}\n".format(
                        query_id, str(pos_pid)+":"+str(round(pos_score,3)), ','.join(
                            str(neg_pid)+":"+str(round(neg_scores[neg_pid],3)) for neg_pid in query_negative_passage[query_id])))
        train_query_cache.close()

        ndcg_output_path = os.path.join(
            args.output_dir, "ann_ndcg_" + str(output_num))
        with open(ndcg_output_path, 'w') as f:
            json.dump({'ndcg': dev_ndcg, 'checkpoint': checkpoint_path}, f)

        return dev_ndcg, num_queries_dev
Example #15
def preprocess(args):

    pid2offset = {}

    in_passage_path = os.path.join(
        args.data_dir,
        args.doc_collection_tsv,
    )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )
    pid2offset_path = os.path.join(
        args.out_data_dir,
        "pid2offset.pickle",
    )
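    # If the merged passage file already exists, just reload pid2offset; otherwise rebuild both from the collection TSV.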
    if os.path.isfile(out_passage_path):
        print("preprocessed passage data already exist, skip")
        with open(pid2offset_path, 'rb') as handle:
            pid2offset = pickle.load(handle)
    else:
        out_line_count = 0

        print('start passage file split processing')
        multi_file_process(args, args.n_split_process, in_passage_path,
                           out_passage_path, PassagePreprocessingFn)

        print('start merging splits')
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(
                    numbered_byte_file_generator(
                        out_passage_path, args.n_split_process,
                        8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1
        print("Total lines written: " + str(out_line_count))
        for i in range(args.n_split_process):
            os.remove('{}_split{}'.format(out_passage_path,
                                          i))  # delete intermediate files
        meta = {
            'type': 'int32',
            'total_number': out_line_count,
            'embedding_size': args.max_seq_length
        }
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        embedding_cache = EmbeddingCache(out_passage_path)
        print("First line")
        with embedding_cache as emb:
            print(emb[0])

        with open(pid2offset_path, 'wb') as handle:
            pickle.dump(pid2offset, handle, protocol=4)
        print("done saving pid2offset")

    if args.qrel_tsv is not None:
        write_query_rel(args, pid2offset, args.query_collection_tsv,
                        args.qrel_tsv, "{}-query".format(args.save_prefix),
                        "{}-qrel.tsv".format(args.save_prefix))
    else:
        write_query_norel(
            args,
            args.save_prefix,
            args.query_collection_tsv,
            "{}-query".format(args.save_prefix),
        )
Example #16
def generate_new_ann(args, output_num, checkpoint_path, preloaded_data,
                     latest_step_num):
    #passage_text, train_pos_id, train_answers, test_answers, test_pos_id = preloaded_data
    #print(test_pos_id.shape)
    model = load_model(args, checkpoint_path)
    pid2offset, offset2pid = load_mapping(args.data_dir, "pid2offset")

    logger.info("***** inference of train query *****")
    train_query_collection_path = os.path.join(args.data_dir,
                                               "train-sec-query")
    train_query_cache = EmbeddingCache(train_query_collection_path)
    with train_query_cache as emb:
        query_embedding, query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "query_sec_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of dev query *****")
    dev_query_collection_path = os.path.join(args.data_dir, "dev-sec-query")
    dev_query_cache = EmbeddingCache(dev_query_collection_path)
    with dev_query_cache as emb:
        dev_query_embedding, dev_query_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=True),
            "dev_sec_query_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=True)

    logger.info("***** inference of passages *****")
    passage_collection_path = os.path.join(args.data_dir, "passages")
    passage_cache = EmbeddingCache(passage_collection_path)
    with passage_cache as emb:
        passage_embedding, passage_embedding2id = StreamInferenceDoc(
            args,
            model,
            GetProcessingFn(args, query=False),
            "passage_" + str(latest_step_num) + "_",
            emb,
            is_query_inference=False,
            load_cache=False)
    logger.info("***** Done passage inference *****")

    if is_first_worker():
        train_pos_id, test_pos_id = preloaded_data
        dim = passage_embedding.shape[1]
        print('passage embedding shape: ' + str(passage_embedding.shape))
        top_k = args.topk_training
        #num_q = passage_embedding.shape[0]

        faiss.omp_set_num_threads(16)
        cpu_index = faiss.IndexFlatIP(dim)
        cpu_index.add(passage_embedding)

        logger.info('Data indexing completed.')

        logger.info("Start searching for query embedding with length %d",
                    len(query_embedding))

        II = list()
        for i in range(15):
            _, idx = cpu_index.search(query_embedding[i * 5000:(i + 1) * 5000],
                                      top_k)  #I: [number of queries, topk]
            II.append(idx)
            logger.info("Split done %d", i)

        I = II[0]
        for i in range(1, 15):
            I = np.concatenate((I, II[i]), axis=0)

        logger.info("***** GenerateNegativePassaageID *****")
        effective_q_id = set(query_embedding2id.flatten())

        #logger.info("Effective qid length %d, search result length %d", len(effective_q_id), I.shape[0])
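        # Build the negative lists inline: keep ANN hits that are not in the query's comma-separated
        # positive-id list and have not been collected already.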
        query_negative_passage = dict()
        for query_idx in range(I.shape[0]):
            query_id = query_embedding2id[query_idx]
            doc_ids = [passage_embedding2id[pidx] for pidx in I[query_idx]]

            neg_docs = list()
            for doc_id in doc_ids:
                pos_id = [
                    int(p_id) for p_id in train_pos_id[query_id].split(',')
                ]

                if doc_id in pos_id:
                    continue
                if doc_id in neg_docs:
                    continue
                neg_docs.append(doc_id)

            query_negative_passage[query_id] = neg_docs

        logger.info("Done generating negative passages, output length %d",
                    len(query_negative_passage))

        logger.info("***** Construct ANN Triplet *****")
        train_data_output_path = os.path.join(
            args.output_dir, "ann_training_data_" + str(output_num))

        with open(train_data_output_path, 'w') as f:
            query_range = list(range(I.shape[0]))
            random.shuffle(query_range)
            for query_idx in query_range:
                query_id = query_embedding2id[query_idx]
                # if not query_id in train_pos_id:
                #     continue
                pos_pid = train_pos_id[query_id]
                f.write("{}\t{}\t{}\n".format(
                    query_id, pos_pid, ','.join(
                        str(neg_pid)
                        for neg_pid in query_negative_passage[query_id])))

        _, dev_I = cpu_index.search(dev_query_embedding,
                                    10)  #I: [number of queries, topk]

        top_k_hits = validate(test_pos_id, dev_I, dev_query_embedding2id,
                              passage_embedding2id)