Example #1
def merge_with_es(query_data, question_data, top_n=5):
    # For each question, look up its query, fetch the top-N Elasticsearch
    # paragraphs, and merge everything into one output record per question.
    out_data = []

    for chunk in list(chunks(question_data, 100)):
        queries = []
        for datum in chunk:
            _id = datum['_id']
            queries.append(query_data[_id] if isinstance(query_data[_id], str)
                           else query_data[_id][0][0])

        es_results = bulk_text_query(queries, topn=top_n, lazy=False)
        for es_result, datum in zip(es_results, chunk):
            _id = datum['_id']
            question_t = datum['question']
            query = query_data[_id] if isinstance(
                query_data[_id], str) else query_data[_id][0][0]
            context = make_context(question_t, es_result)
            json_context = [[p['title'], p['data_object']['text']]
                            for p in es_result]

            out_data.append({
                '_id': _id,
                'question': question_t,
                'context': context,
                'query': query,
                'json_context': json_context
            })
    print("查询es完毕")
    return out_data
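
The `query_data[_id]` lookup above handles two shapes: either a plain query string, or a nested list of candidates from which the top one is taken via `[0][0]`. A minimal sketch of that branch; the nested values (and the score column) are illustrative assumptions, only the string-vs-list check and the `[0][0]` indexing come from the code above:

query_data = {
    'q1': 'plain query string',                      # already a string
    'q2': [['top ranked candidate', 0.93],           # assumed nested shape;
           ['runner-up candidate', 0.41]],           # the code takes [0][0]
}
for _id in ('q1', 'q2'):
    entry = query_data[_id]
    query = entry if isinstance(entry, str) else entry[0][0]
    print(query)   # -> 'plain query string', then 'top ranked candidate'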
Example #2
def main(query_file, question_file, out_file, top_n):
    # Like Example #1, but driven from files: read queries and questions,
    # attach the top-N Elasticsearch paragraphs, and write the result to disk.
    query_data = load_json_file(query_file)
    question_data = load_json_file(question_file)

    out_data = []

    for chunk in tqdm(list(chunks(question_data, 100))):
        queries = []
        for datum in chunk:
            _id = datum['_id']
            queries.append(query_data[_id] if isinstance(query_data[_id], str)
                           else query_data[_id][0][0])

        es_results = bulk_text_query(queries, topn=top_n, lazy=False)
        for es_result, datum in zip(es_results, chunk):
            _id = datum['_id']
            question = datum['question']
            query = query_data[_id] if isinstance(
                query_data[_id], str) else query_data[_id][0][0]
            context = make_context(question, es_result)
            json_context = [[p['title'], p['data_object']['text']]
                            for p in es_result]

            out_data.append({
                '_id': _id,
                'question': question,
                'context': context,
                'query': query,
                'json_context': json_context
            })

        # rewrite the output file after every chunk so partial results are kept
        write_json_file(out_data, out_file)
Example #3
def main():
    # Build a "single-hop" version of HotpotQA: for each question, replace the
    # provided context with the top-10 paragraphs returned by Elasticsearch.
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('split', choices=['train', 'dev'])

    args = parser.parse_args()

    if args.split == 'train':
        filename = 'data/hotpotqa/hotpot_train_v1.1.json'
        outputname = 'data/hotpotqa/hotpot_train_single_hop.json'
    else:
        filename = 'data/hotpotqa/hotpot_dev_fullwiki_v1.json'
        outputname = 'data/hotpotqa/hotpot_dev_single_hop.json'
    batch_size = 64

    with open(filename) as f:
        data = json.load(f)

    outputdata = []
    processed = 0
    for batch in tqdm(chunks(data, batch_size), total=(len(data) + batch_size - 1) // batch_size):
        queries = [x['question'] for x in batch]
        res = bulk_text_query(queries, topn=10, lazy=False)
        for r, d in zip(res, batch):
            d1 = copy(d)  # shallow copy, so the original entry keeps its context
            context = [item['data_object'] for item in r]
            context = [(x['title'], x['text']) for x in context]
            d1['context'] = context
            outputdata.append(d1)

        processed += len(batch)

    with open(outputname, 'w') as f:
        json.dump(outputdata, f)
Example #4
def main():
    # Measure how well a single Elasticsearch query with the raw question
    # retrieves the two gold supporting paragraphs (Hits@N over HotpotQA).
    import argparse
    parser = argparse.ArgumentParser()

    parser.add_argument('split', choices=['train', 'dev'])

    args = parser.parse_args()

    if args.split == 'train':
        filename = 'data/hotpotqa/hotpot_train_v1.1.json'
    else:
        filename = 'data/hotpotqa/hotpot_dev_fullwiki_v1.json'
    batch_size = 64
    Ns = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    max_n = max(Ns)

    with open(filename) as f:
        data = json.load(f)

    batches = [[
        (x['question'], set(y[0] for y in x['supporting_facts']))
        for x in data[b * batch_size:min((b + 1) * batch_size, len(data))]
    ] for b in range((len(data) + batch_size - 1) // batch_size)]

    para1 = Counter()
    para2 = Counter()
    processed = 0
    for batch in tqdm(batches):
        queries = [x[0] for x in batch]
        res = bulk_text_query(queries, topn=max_n, lazy=True)
        # set lazy to true because we don't really care about the json object here
        for r, d in zip(res, batch):
            para1_found = False
            para2_found = False
            for i, para in enumerate(r):
                if para['title'] in d[1]:
                    if not para1_found:
                        para1[i] += 1
                        para1_found = True
                    else:
                        assert not para2_found
                        para2[i] += 1
                        para2_found = True

            if not para1_found:
                para1[max_n] += 1
            if not para2_found:
                para2[max_n] += 1

        processed += len(batch)

    for n in Ns:
        c1 = sum(para1[k] for k in range(n))
        c2 = sum(para2[k] for k in range(n))

        print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format(
            n, 100 * (c1 + c2) / 2 / processed, n, 100 * c1 / processed, n,
            100 * c2 / processed))
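
For reference, the Hits@n / P1@n / P2@n figures at the end are computed from the two rank histograms: para1[k] and para2[k] count the questions whose first and second gold paragraph were retrieved at rank k, with rank max_n standing for "not retrieved". A toy sketch of the arithmetic, with made-up counts:

from collections import Counter

para1 = Counter({0: 80, 1: 15, 10: 5})    # hypothetical rank histogram, max_n = 10
para2 = Counter({0: 40, 3: 30, 10: 30})
processed = 100

n = 5
c1 = sum(para1[k] for k in range(n))      # 95 questions: gold paragraph 1 in top 5
c2 = sum(para2[k] for k in range(n))      # 70 questions: gold paragraph 2 in top 5
print(100 * c1 / processed)               # P1@5   = 95.00
print(100 * c2 / processed)               # P2@5   = 70.00
print(100 * (c1 + c2) / 2 / processed)    # Hits@5 = 82.50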
Example #5
def analyze(hop2_results):
    # Same Hits@N analysis, but for second-hop results: query Elasticsearch
    # with each entry's 'label' and track the ranks at which title1/title2 appear.
    batch_size = 128
    Ns = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    max_n = max(Ns)
    p1_hits = Counter()
    p2_hits = Counter()
    processed = 0

    for chunk in tqdm(chunks(hop2_results, batch_size)):

        label2s = [x['label'] for x in chunk]
        es_bulk_results = bulk_text_query(label2s, topn=max_n, lazy=False)

        for entry, es_results in zip(chunk, es_bulk_results):
            q = entry['question']
            l2 = entry['label']
            t1 = entry['title1']
            p1 = entry['para1']
            t2 = entry['title2']
            p2 = entry['para2']

            # find the ranks of t1 and t2 in es_results
            found_t1 = False
            found_t2 = False
            t2_rank = max_n
            for i, es_entry in enumerate(es_results):
                if es_entry['title'] == t1:
                    p1_hits[i] += 1
                    found_t1 = True
                if es_entry['title'] == t2:
                    p2_hits[i] += 1
                    t2_rank = i
                    found_t2 = True
            if not found_t1:
                p1_hits[max_n] += 1
            if not found_t2:
                p2_hits[max_n] += 1

            print_cols = [q, l2, t1, p1, t2, p2, str(t2_rank + 1)]
            #print('\t'.join(print_cols))
            processed += 1

    for n in Ns:
        c1 = sum(p1_hits[k] for k in range(n))
        c2 = sum(p2_hits[k] for k in range(n))

        print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format(
            n, 100 * (c1+c2) / 2 / processed, n, 100 * c1 / processed, n, 100 * c2 / processed))
Example #6
def deduped_bulk_query(queries1, topn=10, lazy=True):
    # consolidate queries to remove redundancy
    queries2 = []
    queries2_dict = dict()
    mapped_idx = []
    for q in queries1:
        if q not in queries2_dict:
            queries2_dict[q] = len(queries2)
            queries2.append(q)
        mapped_idx.append(queries2_dict[q])

    res1 = bulk_text_query(queries2, topn=topn, lazy=lazy)

    # map queries back
    res = [res1[idx] for idx in mapped_idx]

    return res
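
A self-contained usage sketch: bulk_text_query is stubbed out below so the example runs without Elasticsearch, assuming deduped_bulk_query and the stub live in the same module so the name resolves to the stub at call time. Duplicate queries are sent only once, and their results are shared back by position.

def bulk_text_query(queries, topn=10, lazy=True):   # stub for illustration only
    return [[{'title': 'hit for: ' + q}] for q in queries]

queries = ['who wrote hamlet', 'capital of france', 'who wrote hamlet']
results = deduped_bulk_query(queries, topn=1)
assert len(results) == 3                 # one result list per input query
assert results[0] is results[2]          # duplicates map back to the same object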