Example #1
def preprocess(args):

    pid2offset = {}
    # data_type 0 selects the MS MARCO document collection; any other value selects the passage collection
    if args.data_type == 0:
        in_passage_path = os.path.join(
            args.data_dir,
            "msmarco-docs.tsv",
        )
    else:
        in_passage_path = os.path.join(
            args.data_dir,
            "collection.tsv",
        )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return

    out_line_count = 0

    print('start passage file split processing')
    # preprocess the collection in 32 parallel splits; each split writes fixed-width binary records
    multi_file_process(args, 32, in_passage_path, out_passage_path,
                       PassagePreprocessingFn)

    print('start merging splits')
    with open(out_passage_path, 'wb') as f:
        # each split record is 8 bytes of big-endian passage id followed by a
        # fixed-width payload of 4 + max_seq_length * 4 bytes
        for idx, record in enumerate(
                numbered_byte_file_generator(out_passage_path, 32,
                                             8 + 4 + args.max_seq_length * 4)):
            p_id = int.from_bytes(record[:8], 'big')
            # copy the payload into the merged file and remember its row index
            f.write(record[8:])
            pid2offset[p_id] = idx
            if idx < 3:
                print(str(idx) + " " + str(p_id))
            out_line_count += 1

    print("Total lines written: " + str(out_line_count))
    meta = {
        'type': 'int32',
        'total_number': out_line_count,
        'embedding_size': args.max_seq_length
    }
    with open(out_passage_path + "_meta", 'w') as f:
        json.dump(meta, f)
    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])

    pid2offset_path = os.path.join(
        args.out_data_dir,
        "pid2offset.pickle",
    )
    with open(pid2offset_path, 'wb') as handle:
        pickle.dump(pid2offset, handle, protocol=4)
    print("done saving pid2offset")

    # write train and dev query caches plus remapped qrels for the selected dataset
    if args.data_type == 0:
        write_query_rel(args, pid2offset, "msmarco-doctrain-queries.tsv",
                        "msmarco-doctrain-qrels.tsv", "train-query",
                        "train-qrel.tsv")
        write_query_rel(args, pid2offset, "msmarco-test2019-queries.tsv",
                        "2019qrels-docs.txt", "dev-query", "dev-qrel.tsv")
    else:
        write_query_rel(args, pid2offset, "queries.train.tsv",
                        "qrels.train.tsv", "train-query", "train-qrel.tsv")
        write_query_rel(args, pid2offset, "queries.dev.small.tsv",
                        "qrels.dev.small.tsv", "dev-query", "dev-qrel.tsv")
Example #2
def write_query_rel(args, pid2offset, query_file, positive_id_file,
                    out_query_file, out_id_file):

    print("Writing query files " + str(out_query_file) + " and " +
          str(out_id_file))
    query_positive_id = set()

    query_positive_id_path = os.path.join(
        args.data_dir,
        positive_id_file,
    )

    print("Loading query_2_pos_docid")
    # open the qrels file, transparently handling a gzip-compressed (.gz) variant
    with gzip.open(query_positive_id_path, 'rt',
                   encoding='utf8') if positive_id_file[-2:] == "gz" else open(
                       query_positive_id_path, 'r', encoding='utf8') as f:
        # document qrels (data_type == 0) are space-delimited; passage qrels are tab-separated
        if args.data_type == 0:
            tsvreader = csv.reader(f, delimiter=" ")
        else:
            tsvreader = csv.reader(f, delimiter="\t")
        # collect the ids of queries that have at least one relevance judgment
        for [topicid, _, docid, rel] in tsvreader:
            query_positive_id.add(int(topicid))

    query_collection_path = os.path.join(
        args.data_dir,
        query_file,
    )

    out_query_path = os.path.join(
        args.out_data_dir,
        out_query_file,
    )

    qid2offset = {}

    print('start query file split processing')
    multi_file_process(args, 32, query_collection_path, out_query_path,
                       QueryPreprocessingFn)

    print('start merging splits')

    idx = 0
    with open(out_query_path, 'wb') as f:
        for record in numbered_byte_file_generator(
                out_query_path, 32, 8 + 4 + args.max_query_length * 4):
            q_id = int.from_bytes(record[:8], 'big')
            if q_id not in query_positive_id:
                # skip queries that have no relevance judgments in the label set
                continue
            f.write(record[8:])
            qid2offset[q_id] = idx
            idx += 1
            if idx < 3:
                print(str(idx) + " " + str(q_id))

    qid2offset_path = os.path.join(
        args.out_data_dir,
        "qid2offset.pickle",
    )
    with open(qid2offset_path, 'wb') as handle:
        pickle.dump(qid2offset, handle, protocol=4)
    print("done saving qid2offset")

    print("Total lines written: " + str(idx))
    meta = {
        'type': 'int32',
        'total_number': idx,
        'embedding_size': args.max_query_length
    }
    with open(out_query_path + "_meta", 'w') as f:
        json.dump(meta, f)

    embedding_cache = EmbeddingCache(out_query_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])

    out_id_path = os.path.join(
        args.out_data_dir,
        out_id_file,
    )

    print("Writing qrels")
    with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', encoding='utf8') as f, \
            open(out_id_path, "w", encoding='utf-8') as out_id:

        if args.data_type == 0:
            tsvreader = csv.reader(f, delimiter=" ")
        else:
            tsvreader = csv.reader(f, delimiter="\t")
        out_line_count = 0
        for [topicid, _, docid, rel] in tsvreader:
            topicid = int(topicid)
            if args.data_type == 0:
                # MS MARCO document ids are prefixed with "D"; strip it to get an integer id
                docid = int(docid[1:])
            else:
                docid = int(docid)
            # write the judgment as <query offset>\t<passage offset>\t<relevance>
            out_id.write(
                str(qid2offset[topicid]) + "\t" + str(pid2offset[docid]) +
                "\t" + rel + "\n")
            out_line_count += 1
        print("Total lines written: " + str(out_line_count))
Example #3
def preprocess(args):

    pid2offset = {}
    in_passage_path = os.path.join(
        args.wiki_dir,
        "psgs_w100.tsv",
    )
    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return
    else:
        out_line_count = 0

        print('start passage file split processing')
        multi_file_process(args, 32, in_passage_path, out_passage_path, PassagePreprocessingFn)

        print('start merging splits')
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(numbered_byte_file_generator(out_passage_path, 32, 8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1

        print("Total lines written: " + str(out_line_count))
        meta = {'type': 'int32', 'total_number': out_line_count, 'embedding_size': args.max_seq_length}
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        write_mapping(args, pid2offset, "pid2offset")

    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[pid2offset[1]])

    # data_type 0: Natural Questions only; 1: TriviaQA only; any other value merges both training sets
    if args.data_type == 0:
        write_query_rel(args, pid2offset, "nq-train.json", "train-query", "train-ann", "train-data")
    elif args.data_type == 1:
        write_query_rel(args, pid2offset, "trivia-train.json", "train-query", "train-ann", "train-data", "psg_id")
    else:
        # use both training datasets and merge them
        write_query_rel(args, pid2offset, "nq-train.json", "train-query-nq", "train-ann-nq", "train-data-nq")
        write_query_rel(args, pid2offset, "trivia-train.json", "train-query-trivia", "train-ann-trivia", "train-data-trivia", "psg_id")

        with open(args.out_data_dir + "train-query-nq", "rb") as nq_query, \
                open(args.out_data_dir + "train-query-trivia", "rb") as trivia_query, \
                open(args.out_data_dir + "train-query", "wb") as out_query:
            out_query.write(nq_query.read())
            out_query.write(trivia_query.read())

        with open(args.out_data_dir + "train-query-nq_meta", "r", encoding='utf-8') as nq_query, \
                open(args.out_data_dir + "train-query-trivia_meta", "r", encoding='utf-8') as trivia_query, \
                open(args.out_data_dir + "train-query_meta", "w", encoding='utf-8') as out_query:
            a = json.load(nq_query)
            b = json.load(trivia_query)
            meta = {'type': 'int32', 'total_number': a['total_number'] + b['total_number'], 'embedding_size': args.max_seq_length}
            json.dump(meta, out_query)

        embedding_cache = EmbeddingCache(args.out_data_dir + "train-query")
        print("First line after merge")
        with embedding_cache as emb:
            # spot-check an arbitrary entry of the merged cache
            print(emb[58812])

        with open(args.out_data_dir + "train-ann-nq", "r", encoding='utf-8') as nq_ann, \
                open(args.out_data_dir + "train-ann-trivia", "r", encoding='utf-8') as trivia_ann, \
                open(args.out_data_dir + "train-ann", "w", encoding='utf-8') as out_ann:
            out_ann.writelines(nq_ann.readlines())
            out_ann.writelines(trivia_ann.readlines())


    write_query_rel(args, pid2offset, "nq-dev.json", "dev-query", "dev-ann", "dev-data")
    write_query_rel(args, pid2offset, "trivia-dev.json", "dev-query-trivia", "dev-ann-trivia", "dev-data-trivia", "psg_id")
    write_qas_query(args, "nq-test.csv", "test-query")
    write_qas_query(args, "trivia-test.csv", "trivia-test-query")