def merge_with_es(query_data, question_data, top_n=5):
    out_data = []
    for chunk in list(chunks(question_data, 100)):
        queries = []
        for datum in chunk:
            _id = datum['_id']
            queries.append(query_data[_id] if isinstance(query_data[_id], str)
                           else query_data[_id][0][0])
        es_results = bulk_text_query(queries, topn=top_n, lazy=False)
        for es_result, datum in zip(es_results, chunk):
            _id = datum['_id']
            question_t = datum['question']
            query = query_data[_id] if isinstance(query_data[_id], str) else query_data[_id][0][0]
            context = make_context(question_t, es_result)
            json_context = [[p['title'], p['data_object']['text']] for p in es_result]
            out_data.append({
                '_id': _id,
                'question': question_t,
                'context': context,
                'query': query,
                'json_context': json_context
            })
    print("Finished querying ES")
    return out_data
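# The functions in this module iterate over their inputs in fixed-size batches via a
# `chunks` helper that is not defined here. A minimal sketch, assuming it simply
# yields consecutive slices of the input list (hypothetical helper, last slice may
# be shorter than `n`):
def chunks(lst, n):
    for i in range(0, len(lst), n):
        yield lst[i:i + n]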
def main(query_file, question_file, out_file, top_n):
    query_data = load_json_file(query_file)
    question_data = load_json_file(question_file)

    out_data = []
    for chunk in tqdm(list(chunks(question_data, 100))):
        queries = []
        for datum in chunk:
            _id = datum['_id']
            queries.append(query_data[_id] if isinstance(query_data[_id], str)
                           else query_data[_id][0][0])
        es_results = bulk_text_query(queries, topn=top_n, lazy=False)
        for es_result, datum in zip(es_results, chunk):
            _id = datum['_id']
            question = datum['question']
            query = query_data[_id] if isinstance(query_data[_id], str) else query_data[_id][0][0]
            context = make_context(question, es_result)
            json_context = [[p['title'], p['data_object']['text']] for p in es_result]
            out_data.append({
                '_id': _id,
                'question': question,
                'context': context,
                'query': query,
                'json_context': json_context
            })

    write_json_file(out_data, out_file)
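# `load_json_file` and `write_json_file` are assumed JSON I/O helpers that are not
# defined in this module; a minimal sketch of what they are expected to do
# (hypothetical implementations):
import json

def load_json_file(path):
    with open(path) as f:
        return json.load(f)

def write_json_file(data, path):
    with open(path, 'w') as f:
        json.dump(data, f)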
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('split', choices=['train', 'dev'])
    args = parser.parse_args()

    if args.split == 'train':
        filename = 'data/hotpotqa/hotpot_train_v1.1.json'
        outputname = 'data/hotpotqa/hotpot_train_single_hop.json'
    else:
        filename = 'data/hotpotqa/hotpot_dev_fullwiki_v1.json'
        outputname = 'data/hotpotqa/hotpot_dev_single_hop.json'

    batch_size = 64

    with open(filename) as f:
        data = json.load(f)

    outputdata = []
    processed = 0
    for batch in tqdm(chunks(data, batch_size),
                      total=(len(data) + batch_size - 1) // batch_size):
        queries = [x['question'] for x in batch]
        res = bulk_text_query(queries, topn=10, lazy=False)
        for r, d in zip(res, batch):
            d1 = copy(d)
            context = [item['data_object'] for item in r]
            context = [(x['title'], x['text']) for x in context]
            d1['context'] = context
            outputdata.append(d1)
        processed += len(batch)

    with open(outputname, 'w') as f:
        json.dump(outputdata, f)
def main():
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument('split', choices=['train', 'dev'])
    args = parser.parse_args()

    if args.split == 'train':
        filename = 'data/hotpotqa/hotpot_train_v1.1.json'
    else:
        filename = 'data/hotpotqa/hotpot_dev_fullwiki_v1.json'

    batch_size = 64
    Ns = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    max_n = max(Ns)

    with open(filename) as f:
        data = json.load(f)

    batches = [[(x['question'], set(y[0] for y in x['supporting_facts']))
                for x in data[b * batch_size:min((b + 1) * batch_size, len(data))]]
               for b in range((len(data) + batch_size - 1) // batch_size)]

    para1 = Counter()
    para2 = Counter()
    processed = 0
    for batch in tqdm(batches):
        queries = [x[0] for x in batch]
        # set lazy to True because we don't really care about the JSON object here
        res = bulk_text_query(queries, topn=max_n, lazy=True)
        for r, d in zip(res, batch):
            para1_found = False
            para2_found = False
            for i, para in enumerate(r):
                if para['title'] in d[1]:
                    if not para1_found:
                        para1[i] += 1
                        para1_found = True
                    else:
                        assert not para2_found
                        para2[i] += 1
                        para2_found = True
            if not para1_found:
                para1[max_n] += 1
            if not para2_found:
                para2[max_n] += 1
        processed += len(batch)

    for n in Ns:
        c1 = sum(para1[k] for k in range(n))
        c2 = sum(para2[k] for k in range(n))
        print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format(
            n, 100 * (c1 + c2) / 2 / processed,
            n, 100 * c1 / processed,
            n, 100 * c2 / processed))
def analyze(hop2_results):
    batch_size = 128
    Ns = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 15, 20, 25, 30, 35, 40, 45, 50]
    max_n = max(Ns)

    p1_hits = Counter()
    p2_hits = Counter()
    processed = 0
    for chunk in tqdm(chunks(hop2_results, batch_size)):
        label2s = [x['label'] for x in chunk]
        es_bulk_results = bulk_text_query(label2s, topn=max_n, lazy=False)
        for entry, es_results in zip(chunk, es_bulk_results):
            q = entry['question']
            l2 = entry['label']
            t1 = entry['title1']
            p1 = entry['para1']
            t2 = entry['title2']
            p2 = entry['para2']

            # find the ranks of t1 and t2 in es_results
            found_t1 = False
            found_t2 = False
            t2_rank = max_n
            for i, es_entry in enumerate(es_results):
                if es_entry['title'] == t1:
                    p1_hits[i] += 1
                    found_t1 = True
                if es_entry['title'] == t2:
                    p2_hits[i] += 1
                    t2_rank = i
                    found_t2 = True
            if not found_t1:
                p1_hits[max_n] += 1
            if not found_t2:
                p2_hits[max_n] += 1

            print_cols = [q, l2, t1, p1, t2, p2, str(t2_rank + 1)]
            #print('\t'.join(print_cols))
            processed += 1

    for n in Ns:
        c1 = sum(p1_hits[k] for k in range(n))
        c2 = sum(p2_hits[k] for k in range(n))
        print("Hits@{:2d}: {:.2f}\tP1@{:2d}: {:.2f}\tP2@{:2d}: {:.2f}".format(
            n, 100 * (c1 + c2) / 2 / processed,
            n, 100 * c1 / processed,
            n, 100 * c2 / processed))
def deduped_bulk_query(queries1, topn=10, lazy=True):
    # consolidate queries to remove redundancy
    queries2 = []
    queries2_dict = dict()
    mapped_idx = []
    for q in queries1:
        if q not in queries2_dict:
            queries2_dict[q] = len(queries2)
            queries2.append(q)
        mapped_idx.append(queries2_dict[q])

    res1 = bulk_text_query(queries2, topn=topn, lazy=lazy)

    # map results back to the original (possibly duplicated) query order
    res = [res1[idx] for idx in mapped_idx]
    return res
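# Illustration of the dedup-and-fan-out logic used in deduped_bulk_query, with a
# hypothetical stand-in for bulk_text_query so it runs without an Elasticsearch
# backend: each duplicated query is sent to the backend once, and its result is
# copied back to every original position.
def _demo_dedup_mapping():
    queries = ['obama', 'harvard', 'obama']
    seen, uniq, mapped = {}, [], []
    for q in queries:
        if q not in seen:
            seen[q] = len(uniq)
            uniq.append(q)
        mapped.append(seen[q])
    fake_results = ['hits-for-' + q for q in uniq]   # stand-in for bulk_text_query
    fanned_out = [fake_results[i] for i in mapped]
    assert fanned_out == ['hits-for-obama', 'hits-for-harvard', 'hits-for-obama']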