Example #1
def preprocess(args):

    pid2offset = {}
    in_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return
    else:
        out_line_count = 0

        print('start passage file split processing')
        print(args.model_type)
        # NOTE: the split step is commented out here; the 53 "<passages>_split*"
        # files read below are assumed to have been generated beforehand.
        #multi_file_process(args, 32, in_passage_path, out_passage_path, PassagePreprocessingFn)

        print('start merging splits')
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(
                    numbered_byte_file_generator(
                        in_passage_path, 53, 8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1
                if out_line_count % 100000 == 0:
                    print(out_line_count)

        print("Total lines written: " + str(out_line_count))
        meta = {
            'type': 'int32',
            'total_number': out_line_count,
            'embedding_size': args.max_seq_length
        }
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        write_mapping(args, pid2offset, "pid2offset")
    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[pid2offset[1]])

    write_qas_query(args, pid2offset, "hotpot_train_first_hop.json",
                    "train-query", "train-ann")
    write_qas_query(args, pid2offset, "hotpot_dev_first_hop.json", "dev-query",
                    "dev-ann")
    write_qas_query(args, pid2offset, "hotpot_train_sec_hop.json",
                    "train-sec-query", "train-sec-ann")
    write_qas_query(args, pid2offset, "hotpot_dev_sec_hop.json",
                    "dev-sec-query", "dev-sec-ann")
Example #2
def write_query_norel(args, prefix, query_file, out_query_file):

    query_collection_path = os.path.join(
        args.data_dir,
        query_file,
    )

    out_query_path = os.path.join(
        args.out_data_dir,
        out_query_file,
    )

    qid2offset = {}

    print('start query file split processing')
    multi_file_process(args, args.n_split_process, query_collection_path,
                       out_query_path, QueryPreprocessingFn)

    print('start merging splits')

    idx = 0
    with open(out_query_path, 'wb') as f:
        for record in numbered_byte_file_generator(
                out_query_path, args.n_split_process,
                8 + 4 + args.max_query_length * 4):
            q_id = int.from_bytes(record[:8], 'big')
            f.write(record[8:])
            qid2offset[q_id] = idx
            idx += 1
            if idx < 3:
                print(str(idx) + " " + str(q_id))
    for i in range(args.n_split_process):
        os.remove('{}_split{}'.format(out_query_path,
                                      i))  # delete intermediate files

    qid2offset_path = os.path.join(
        args.out_data_dir,
        prefix + "_qid2offset.pickle",
    )
    with open(qid2offset_path, 'wb') as handle:
        pickle.dump(qid2offset, handle, protocol=4)
    print("done saving qid2offset")

    print("Total lines written: " + str(idx))
    meta = {
        'type': 'int32',
        'total_number': idx,
        'embedding_size': args.max_query_length
    }
    with open(out_query_path + "_meta", 'w') as f:
        json.dump(meta, f)

    embedding_cache = EmbeddingCache(out_query_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])
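
The EmbeddingCache class is only exercised here through its context-manager and indexing interface (emb[offset]). Combined with the _meta file written above ({'type': 'int32', 'total_number': ..., 'embedding_size': ...}) and the fact that the merged file stores records of 4 + embedding_size * 4 bytes once the 8-byte id prefix has been stripped, a simplified stand-in reader might look like this; the class in the repository presumably adds memory-mapping or caching.

import json
import numpy as np

class SimpleEmbeddingCache:
    # Hypothetical minimal reader for the merged passage/query files above.
    def __init__(self, base_path):
        with open(base_path + "_meta", 'r') as meta_f:
            meta = json.load(meta_f)
        self.embedding_size = meta['embedding_size']
        self.total_number = meta['total_number']
        # 4-byte length prefix + embedding_size int32 token ids per record.
        self.record_size = 4 + self.embedding_size * 4
        self.base_path = base_path

    def __enter__(self):
        self.f = open(self.base_path, 'rb')
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.f.close()

    def __getitem__(self, offset):
        self.f.seek(offset * self.record_size)
        # Assumes the 4-byte length shares the big-endian encoding of the id.
        length = int.from_bytes(self.f.read(4), 'big')
        tokens = np.frombuffer(self.f.read(self.embedding_size * 4),
                               dtype=np.int32)
        return tokens[:length]
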
Example #3
def preprocess(args):

    pid2offset = {}
    in_passage_path = os.path.join(
        args.wiki_dir,
        "psgs_w100.tsv",
    )
    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return
    else:
        out_line_count = 0

        print('start passage file split processing')
        multi_file_process(args, 32, in_passage_path, out_passage_path,
                           PassagePreprocessingFn)

        print('start merging splits')
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(
                    numbered_byte_file_generator(
                        out_passage_path, 32,
                        8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1

        print("Total lines written: " + str(out_line_count))
        meta = {
            'type': 'int32',
            'total_number': out_line_count,
            'embedding_size': args.max_seq_length
        }
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        write_mapping(args, pid2offset, "pid2offset")

    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[pid2offset[1]])

    if args.data_type == 0:
        write_query_rel(args, pid2offset, "nq-train.json", "train-query",
                        "train-ann", "train-data")
    elif args.data_type == 1:
        write_query_rel(args, pid2offset, "trivia-train.json", "train-query",
                        "train-ann", "train-data", "psg_id")
    else:
        # use both training dataset and merge them
        write_query_rel(args, pid2offset, "nq-train.json", "train-query-nq",
                        "train-ann-nq", "train-data-nq")
        write_query_rel(args, pid2offset, "trivia-train.json",
                        "train-query-trivia", "train-ann-trivia",
                        "train-data-trivia", "psg_id")

        with open(args.out_data_dir + "train-query-nq", "rb") as nq_query, \
                open(args.out_data_dir + "train-query-trivia", "rb") as trivia_query, \
                open(args.out_data_dir + "train-query", "wb") as out_query:
            out_query.write(nq_query.read())
            out_query.write(trivia_query.read())

        with open(args.out_data_dir + "train-query-nq_meta", "r", encoding='utf-8') as nq_query, \
                open(args.out_data_dir + "train-query-trivia_meta", "r", encoding='utf-8') as trivia_query, \
                open(args.out_data_dir + "train-query_meta", "w", encoding='utf-8') as out_query:
            a = json.load(nq_query)
            b = json.load(trivia_query)
            meta = {
                'type': 'int32',
                'total_number': a['total_number'] + b['total_number'],
                'embedding_size': args.max_query_length  # query records are max_query_length long
            }
            json.dump(meta, out_query)

        embedding_cache = EmbeddingCache(args.out_data_dir + "train-query")
        print("First line after merge")
        with embedding_cache as emb:
            print(emb[58812])

        with open(args.out_data_dir + "train-ann-nq", "r", encoding='utf-8') as nq_ann, \
                open(args.out_data_dir + "train-ann-trivia", "r", encoding='utf-8') as trivia_ann, \
                open(args.out_data_dir + "train-ann", "w", encoding='utf-8') as out_ann:
            out_ann.writelines(nq_ann.readlines())
            out_ann.writelines(trivia_ann.readlines())

    write_query_rel(args, pid2offset, "nq-dev.json", "dev-query", "dev-ann",
                    "dev-data")
    write_query_rel(args, pid2offset, "trivia-dev.json", "dev-query-trivia",
                    "dev-ann-trivia", "dev-data-trivia", "psg_id")
    write_qas_query(args, "nq-test.csv", "test-query")
    write_qas_query(args, "trivia-test.csv", "trivia-test-query")
Example #4
def write_query_rel(args, pid2offset, query_file, positive_id_file,
                    out_query_file, out_id_file):

    # args,
    # pid2offset,
    # "msmarco-test2019-queries.tsv",
    # "2019qrels-docs.txt",
    # "dev-query",
    # "dev-qrel.tsv"
    print("Writing query files " + str(out_query_file) + " and " +
          str(out_id_file))
    query_positive_id = set()

    query_positive_id_path = os.path.join(
        args.data_dir,
        positive_id_file,
    )

    print("Loading query_2_pos_docid")
    with gzip.open(query_positive_id_path, 'rt',
                   encoding='utf8') if positive_id_file[-2:] == "gz" else open(
                       query_positive_id_path, 'r', encoding='utf8') as f:
        if args.data_type == 0:
            tsvreader = csv.reader(f, delimiter=" ")
        else:
            tsvreader = csv.reader(f, delimiter="\t")
        for [topicid, _, docid, rel] in tsvreader:
            query_positive_id.add(int(topicid))

    query_collection_path = os.path.join(
        args.data_dir,
        query_file,
    )

    out_query_path = os.path.join(
        args.out_data_dir,
        out_query_file,
    )

    qid2offset = {}

    print('start query file split processing')
    multi_file_process(args, 32, query_collection_path, out_query_path,
                       QueryPreprocessingFn)

    print('start merging splits')

    idx = 0
    with open(out_query_path, 'wb') as f:
        for record in numbered_byte_file_generator(
                out_query_path, 32, 8 + 4 + args.max_query_length * 4):
            q_id = int.from_bytes(record[:8], 'big')
            if q_id not in query_positive_id:
                # exclude the query as it is not in label set
                continue
            f.write(record[8:])
            qid2offset[q_id] = idx
            idx += 1
            if idx < 3:
                print(str(idx) + " " + str(q_id))

    qid2offset_path = os.path.join(
        args.out_data_dir,
        "qid2offset.pickle",
    )
    with open(qid2offset_path, 'wb') as handle:
        pickle.dump(qid2offset, handle, protocol=4)
    print("done saving qid2offset")

    print("Total lines written: " + str(idx))
    meta = {
        'type': 'int32',
        'total_number': idx,
        'embedding_size': args.max_query_length
    }
    with open(out_query_path + "_meta", 'w') as f:
        json.dump(meta, f)

    embedding_cache = EmbeddingCache(out_query_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])

    out_id_path = os.path.join(
        args.out_data_dir,
        out_id_file,
    )

    print("Writing qrels")
    with gzip.open(query_positive_id_path, 'rt', encoding='utf8') if positive_id_file[-2:] == "gz" else open(query_positive_id_path, 'r', encoding='utf8') as f, \
            open(out_id_path, "w", encoding='utf-8') as out_id:

        if args.data_type == 0:
            tsvreader = csv.reader(f, delimiter=" ")
        else:
            tsvreader = csv.reader(f, delimiter="\t")
        out_line_count = 0
        for [topicid, _, docid, rel] in tsvreader:
            topicid = int(topicid)
            if args.data_type == 0:
                docid = int(docid[1:])
            else:
                docid = int(docid)
            out_id.write(
                str(qid2offset[topicid]) + "\t" + str(pid2offset[docid]) +
                "\t" + rel + "\n")
            out_line_count += 1
        print("Total lines written: " + str(out_line_count))
Example #5
def preprocess(args):

    pid2offset = {}
    if args.data_type == 0:
        in_passage_path = os.path.join(
            args.data_dir,
            "msmarco-docs.tsv",
        )
    else:
        in_passage_path = os.path.join(
            args.data_dir,
            "collection.tsv",
        )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )

    if os.path.exists(out_passage_path):
        print("preprocessed data already exist, exit preprocessing")
        return

    out_line_count = 0

    print('start passage file split processing')
    multi_file_process(args, 32, in_passage_path, out_passage_path,
                       PassagePreprocessingFn)

    print('start merging splits')
    with open(out_passage_path, 'wb') as f:
        for idx, record in enumerate(
                numbered_byte_file_generator(out_passage_path, 32,
                                             8 + 4 + args.max_seq_length * 4)):
            p_id = int.from_bytes(record[:8], 'big')
            f.write(record[8:])
            pid2offset[p_id] = idx
            if idx < 3:
                print(str(idx) + " " + str(p_id))
            out_line_count += 1

    print("Total lines written: " + str(out_line_count))
    meta = {
        'type': 'int32',
        'total_number': out_line_count,
        'embedding_size': args.max_seq_length
    }
    with open(out_passage_path + "_meta", 'w') as f:
        json.dump(meta, f)
    embedding_cache = EmbeddingCache(out_passage_path)
    print("First line")
    with embedding_cache as emb:
        print(emb[0])

    pid2offset_path = os.path.join(
        args.out_data_dir,
        "pid2offset.pickle",
    )
    with open(pid2offset_path, 'wb') as handle:
        pickle.dump(pid2offset, handle, protocol=4)
    print("done saving pid2offset")

    if args.data_type == 0:
        write_query_rel(args, pid2offset, "msmarco-doctrain-queries.tsv",
                        "msmarco-doctrain-qrels.tsv", "train-query",
                        "train-qrel.tsv")
        write_query_rel(args, pid2offset, "msmarco-test2019-queries.tsv",
                        "2019qrels-docs.txt", "dev-query", "dev-qrel.tsv")
    else:
        write_query_rel(args, pid2offset, "queries.train.tsv",
                        "qrels.train.tsv", "train-query", "train-qrel.tsv")
        write_query_rel(args, pid2offset, "queries.dev.small.tsv",
                        "qrels.dev.small.tsv", "dev-query", "dev-qrel.tsv")
def preprocess(args):

    pid2offset = {}

    in_passage_path = os.path.join(
        args.data_dir,
        args.doc_collection_tsv,
    )

    out_passage_path = os.path.join(
        args.out_data_dir,
        "passages",
    )
    pid2offset_path = os.path.join(
        args.out_data_dir,
        "pid2offset.pickle",
    )
    if os.path.isfile(out_passage_path):
        print("preprocessed passage data already exist, skip")
        with open(pid2offset_path, 'rb') as handle:
            pid2offset = pickle.load(handle)
    else:
        out_line_count = 0

        print('start passage file split processing')
        multi_file_process(args, args.n_split_process, in_passage_path,
                           out_passage_path, PassagePreprocessingFn)

        print('start merging splits')
        with open(out_passage_path, 'wb') as f:
            for idx, record in enumerate(
                    numbered_byte_file_generator(
                        out_passage_path, args.n_split_process,
                        8 + 4 + args.max_seq_length * 4)):
                p_id = int.from_bytes(record[:8], 'big')
                f.write(record[8:])
                pid2offset[p_id] = idx
                if idx < 3:
                    print(str(idx) + " " + str(p_id))
                out_line_count += 1
        print("Total lines written: " + str(out_line_count))
        for i in range(args.n_split_process):
            os.remove('{}_split{}'.format(out_passage_path,
                                          i))  # delete intermediate files
        meta = {
            'type': 'int32',
            'total_number': out_line_count,
            'embedding_size': args.max_seq_length
        }
        with open(out_passage_path + "_meta", 'w') as f:
            json.dump(meta, f)
        embedding_cache = EmbeddingCache(out_passage_path)
        print("First line")
        with embedding_cache as emb:
            print(emb[0])

        with open(pid2offset_path, 'wb') as handle:
            pickle.dump(pid2offset, handle, protocol=4)
        print("done saving pid2offset")

    if args.qrel_tsv is not None:
        write_query_rel(args, pid2offset, args.query_collection_tsv,
                        args.qrel_tsv, "{}-query".format(args.save_prefix),
                        "{}-qrel.tsv".format(args.save_prefix))
    else:
        write_query_norel(
            args,
            args.save_prefix,
            args.query_collection_tsv,
            "{}-query".format(args.save_prefix),
        )
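
For reference, the args object these functions consume needs at least the attributes accessed above. A hypothetical command-line front end (argument names collected from the code, defaults purely illustrative) could be wired up like this:

import argparse

def get_arguments():
    parser = argparse.ArgumentParser()
    # Paths
    parser.add_argument("--data_dir", type=str, required=True)
    parser.add_argument("--out_data_dir", type=str, required=True)
    parser.add_argument("--wiki_dir", type=str, default=None)
    parser.add_argument("--doc_collection_tsv", type=str, default=None)
    parser.add_argument("--query_collection_tsv", type=str, default=None)
    parser.add_argument("--qrel_tsv", type=str, default=None)
    # Preprocessing settings
    parser.add_argument("--model_type", type=str, default=None)
    parser.add_argument("--max_seq_length", type=int, default=256)
    parser.add_argument("--max_query_length", type=int, default=64)
    parser.add_argument("--data_type", type=int, default=0)
    parser.add_argument("--n_split_process", type=int, default=32)
    parser.add_argument("--save_prefix", type=str, default="train")
    return parser.parse_args()

if __name__ == "__main__":
    preprocess(get_arguments())
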