# Imports assumed by the snippets in this section (which collects functions
# from several modules of the densephrases repo). Repo-internal helpers
# (TrueCaser, Options, load_phrase_index, load_encoder, get_query2vec,
# load_cross_encoder, get_cq_dataloader, get_cq_results, get_bertqa_dataloader,
# get_bertqa_results, compute_predictions_logits, evaluate_results,
# evaluate_results_kilt) are assumed to be imported from the densephrases package.
import copy
import json
import logging
import os
import random
from time import time

import numpy as np
import requests
from flask import Flask, jsonify, request
from flask_cors import CORS
from requests_futures.sessions import FuturesSession
from tornado.httpserver import HTTPServer
from tornado.ioloop import IOLoop
from tornado.wsgi import WSGIContainer
from tqdm import tqdm

logger = logging.getLogger(__name__)


def load_qa_pairs(data_path, args, draft_num_examples=1000, shuffle=False):
    q_ids = []
    questions = []
    answers = []
    data = json.load(open(data_path))['data']
    for item in data:
        q_id = item['id']
        question = item['question']
        answer = item['answers']
        if len(answer) == 0:
            continue
        q_ids.append(q_id)
        questions.append(question)
        answers.append(answer)
    # Strip trailing question marks
    questions = [query[:-1] if query.endswith('?') else query for query in questions]

    if args.truecase:
        try:
            logger.info('Loading truecaser for queries')
            truecase = TrueCaser(os.path.join(os.environ['DPH_DATA_DIR'], args.truecase_path))
            questions = [
                truecase.get_true_case(query) if query == query.lower() else query
                for query in questions
            ]
        except Exception as e:
            logger.warning(e)

    if args.do_lower_case:
        logger.info('Lowercasing queries')
        questions = [query.lower() for query in questions]

    if args.draft:
        q_ids = q_ids[:draft_num_examples]
        questions = questions[:draft_num_examples]
        answers = answers[:draft_num_examples]

    if shuffle:
        logger.info('Shuffling QA pairs')
        qa_pairs = list(zip(q_ids, questions, answers))
        random.shuffle(qa_pairs)
        q_ids, questions, answers = zip(*qa_pairs)

    logger.info(f'Loading {len(questions)} questions from {data_path}')
    logger.info(f'Sample Q ({q_ids[0]}): {questions[0]}, A: {answers[0]}')
    return q_ids, questions, answers
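# A minimal usage sketch for load_qa_pairs above (not part of the original
# module). The JSON layout is inferred from the keys the function reads;
# the concrete values and the SimpleNamespace stand-in for the parsed
# argparse options are hypothetical.
def _example_load_qa_pairs():
    import tempfile
    from types import SimpleNamespace

    record = {'id': 'q-0001',
              'question': 'Who developed the theory of relativity?',
              'answers': ['Albert Einstein', 'Einstein']}
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as fp:
        json.dump({'data': [record]}, fp)
        path = fp.name

    demo_args = SimpleNamespace(truecase=False, do_lower_case=True, draft=False,
                                truecase_path='truecase.pkl')
    q_ids, questions, answers = load_qa_pairs(path, demo_args)
    # The trailing '?' is stripped and the query is lowercased
    assert questions[0] == 'who developed the theory of relativity'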
class DensePhrases(object):
    def __init__(self,
                 load_dir,
                 dump_dir,
                 index_name='start/1048576_flat_OPQ96',
                 device='cuda',
                 verbose=False,
                 **kwargs):
        print("This could take up to 15 mins depending on the file reading speed of HDD/SSD")

        # Turn off loggers
        if not verbose:
            logging.getLogger("densephrases").setLevel(logging.WARNING)
            logging.getLogger("transformers").setLevel(logging.WARNING)

        # Get default options
        options = Options()
        options.add_model_options()
        options.add_index_options()
        options.add_retrieval_options()
        options.add_data_options()
        self.args = options.parse()

        # Set options
        self.args.load_dir = load_dir
        self.args.dump_dir = dump_dir
        self.args.cache_dir = os.environ['CACHE_DIR']
        self.args.index_name = index_name
        self.args.cuda = device == 'cuda'
        self.args.__dict__.update(kwargs)

        # Load encoder
        self.set_encoder(load_dir, device)

        # Load MIPS
        self.mips = load_phrase_index(self.args, ignore_logging=not verbose)

        # Others
        self.truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], self.args.truecase_path))
        print("Loading DensePhrases Completed!")

    def search(self, query='', retrieval_unit='phrase', top_k=10, truecase=True, return_meta=False):
        # A single str query is wrapped into a batch of one
        single_query = False
        if isinstance(query, str):
            batch_query = [query]
            single_query = True
        else:
            assert isinstance(query, list)
            batch_query = query

        # Pre-processing: truecase queries that are all lowercase
        if truecase:
            batch_query = [
                self.truecase.get_true_case(q) if q == q.lower() else q
                for q in batch_query
            ]

        # Get question vector
        outs = self.query2vec(batch_query)
        start = np.concatenate([out[0] for out in outs], 0)
        end = np.concatenate([out[1] for out in outs], 0)
        query_vec = np.concatenate([start, end], 1)

        # Search
        agg_strats = {'phrase': 'opt1', 'sentence': 'opt2', 'paragraph': 'opt2', 'document': 'opt3'}
        if retrieval_unit not in agg_strats:
            raise NotImplementedError(
                f'"{retrieval_unit}" not supported. Choose one of {list(agg_strats.keys())}.')
        # Fetch extra candidates for units that are deduplicated during aggregation
        search_top_k = top_k
        if retrieval_unit in ['sentence', 'paragraph', 'document']:
            search_top_k *= 2
        rets = self.mips.search(
            query_vec,
            q_texts=batch_query,
            nprobe=256,
            top_k=search_top_k,
            max_answer_length=10,
            return_idxs=False,
            aggregate=True,
            agg_strat=agg_strats[retrieval_unit],
            return_sent=(retrieval_unit == 'sentence'),
        )

        # Gather results
        rets = [ret[:top_k] for ret in rets]
        if retrieval_unit == 'phrase':
            retrieved = [[rr['answer'] for rr in ret] for ret in rets]
        elif retrieval_unit in ['sentence', 'paragraph']:
            retrieved = [[rr['context'] for rr in ret] for ret in rets]
        elif retrieval_unit == 'document':
            retrieved = [[rr['title'][0] for rr in ret] for ret in rets]
        else:
            raise NotImplementedError()

        if single_query:
            rets = rets[0]
            retrieved = retrieved[0]
        if return_meta:
            return retrieved, rets
        return retrieved

    def set_encoder(self, load_dir, device='cuda'):
        self.args.load_dir = load_dir
        self.model, self.tokenizer, self.config = load_encoder(device, self.args)
        self.query2vec = get_query2vec(
            query_encoder=self.model, tokenizer=self.tokenizer, args=self.args, batch_size=64)

    def evaluate(self, test_path, **kwargs):
        from eval_phrase_retrieval import evaluate as evaluate_fn

        # Set new arguments
        new_args = copy.deepcopy(self.args)
        new_args.test_path = test_path
        new_args.truecase = True
        new_args.__dict__.update(kwargs)

        # Run with new args
        evaluate_fn(new_args, self.mips, self.model, self.tokenizer)
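# A hedged usage sketch for the DensePhrases class above (not part of the
# original module). The load_dir/dump_dir values are hypothetical placeholders:
# they should point to a trained query encoder and a phrase dump, and the
# CACHE_DIR / DATA_DIR environment variables must be set.
def _example_densephrases_search():
    model = DensePhrases(
        load_dir='outputs/densephrases-multi-query-multi',
        dump_dir='outputs/densephrases-multi_wiki-20181220/dump')

    # Top-5 phrase-level answers for a single query
    print(model.search('Who won the Nobel Prize in peace?', retrieval_unit='phrase', top_k=5))

    # Batched queries, document-level retrieval with full result metadata
    docs, meta = model.search(
        ['When was Barack Obama born?', 'What is the capital of France?'],
        retrieval_unit='document', return_meta=True)
    print(docs)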
class DensePhrasesDemo(object):
    def __init__(self, args):
        self.args = args
        self.base_ip = args.base_ip
        self.query_port = args.query_port
        self.index_port = args.index_port
        self.truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], args.truecase_path))

    def serve_query_encoder(self, query_port, args, inmemory=False, batch_size=64,
                            query_encoder=None, tokenizer=None):
        device = 'cuda' if args.cuda else 'cpu'
        if query_encoder is None:
            query_encoder, tokenizer, _ = load_encoder(device, args)
        query2vec = get_query2vec(
            query_encoder=query_encoder, tokenizer=tokenizer, args=args, batch_size=batch_size)

        # Serve query encoder
        app = Flask(__name__)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        @app.route('/query2vec_api', methods=['POST'])
        def query2vec_api():
            batch_query = json.loads(request.form['query'])
            outs = query2vec(batch_query)
            return jsonify(outs)

        logger.info(f'Starting QueryEncoder server at {self.get_address(query_port)}')
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(query_port)
        IOLoop.instance().start()

    def serve_phrase_index(self, index_port, args):
        args.examples_path = os.path.join('densephrases/demo/static', args.examples_path)

        # Load mips
        mips = load_phrase_index(args)
        app = Flask(__name__, static_folder='./densephrases/demo/static/')
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        def batch_search(batch_query, max_answer_length=20, top_k=10, nprobe=64, return_idxs=False):
            t0 = time()
            # embed_query returns a deferred callable; calling it resolves the request
            outs, _ = self.embed_query(batch_query)()
            start = np.concatenate([out[0] for out in outs], 0)
            end = np.concatenate([out[1] for out in outs], 0)
            query_vec = np.concatenate([start, end], 1)
            rets = mips.search(
                query_vec,
                q_texts=batch_query,
                nprobe=nprobe,
                top_k=top_k,
                max_answer_length=max_answer_length,
                return_idxs=return_idxs,
                aggregate=True,
            )
            for ret_idx, ret in enumerate(rets):
                for rr in ret:
                    rr['query_tokens'] = outs[ret_idx][2]
            t1 = time()
            return {'ret': rets, 'time': int(1000 * (t1 - t0))}

        @app.route('/')
        def index():
            return app.send_static_file('index.html')

        @app.route('/files/<path:path>')
        def static_files(path):
            return app.send_static_file('files/' + path)

        # This one uses default hyperparameters (for the demo)
        @app.route('/api', methods=['GET'])
        def api():
            query = request.args['query']
            query = query[:-1] if query.endswith('?') else query
            if args.truecase:
                # Truecase if everything after the first character is lowercase
                if query[1:].lower() == query[1:]:
                    query = self.truecase.get_true_case(query)
            out = batch_search(
                [query],
                max_answer_length=args.max_answer_length,
                top_k=args.top_k,
                nprobe=args.nprobe,
            )
            out['ret'] = out['ret'][0]
            return jsonify(out)

        @app.route('/batch_api', methods=['POST'])
        def batch_api():
            batch_query = json.loads(request.form['query'])
            max_answer_length = int(request.form['max_answer_length'])
            top_k = int(request.form['top_k'])
            nprobe = int(request.form['nprobe'])
            out = batch_search(
                batch_query,
                max_answer_length=max_answer_length,
                top_k=top_k,
                nprobe=nprobe,
            )
            return jsonify(out)

        @app.route('/get_examples', methods=['GET'])
        def get_examples():
            with open(args.examples_path, 'r') as fp:
                examples = [line.strip() for line in fp.readlines()]
            return jsonify(examples)

        if self.query_port is None:
            logger.info('You must set self.query_port for querying. '
                        'You can use self.update_query_port() later on.')
        logger.info(f'Starting Index server at {self.get_address(index_port)}')
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(index_port)
        IOLoop.instance().start()

    def serve_bert_encoder(self, bert_port, args):
        device = 'cuda' if args.cuda else 'cpu'
        # Will be just a BERT model used as the query encoder
        bert_encoder, tokenizer = load_cross_encoder(device, args)

        def float_to_hex(vals):
            # Min-max normalize the values to [0, 255] and render as two-digit hex
            strs = []
            minv = min(vals)
            maxv = max(vals)
            for val in vals:
                strs.append('{0:0{1}X}'.format(
                    int(min((val - minv) / (maxv - minv) * 255, 255)), 2))
            return strs

        # Define context + query to logit function
        def context_query_to_logit(context, query):
            bert_encoder.eval()

            # Phrase encoding style
            dataloader, examples, features = get_cq_dataloader(
                [context], [query], tokenizer, args.max_query_length, batch_size=64)
            cq_results = get_cq_results(
                examples, features, dataloader, device, bert_encoder, batch_size=64)

            outs = []
            for cq_idx, cq_result in enumerate(cq_results):
                # Max over end positions of the pairwise start+end logit matrix
                all_logits = (
                    np.expand_dims(np.array(cq_result.start_logits), axis=1) +
                    np.expand_dims(np.array(cq_result.end_logits), axis=0)).max(1).tolist()
                out = {
                    'context': ' '.join(features[cq_idx].tokens),
                    'title': 'dummy',
                    'start_logits': float_to_hex(all_logits[0:len(features[cq_idx].tokens)]),
                    'end_logits': float_to_hex(
                        cq_result.end_logits[0:len(features[cq_idx].tokens)]),
                }
                outs.append(out)
            return outs

        def context_query_to_answer(batch_context, batch_query):
            bert_encoder.eval()

            # Phrase encoding style
            dataloader, examples, features = get_bertqa_dataloader(
                batch_context, batch_query, tokenizer, args.max_query_length, batch_size=64)
            cq_results = get_bertqa_results(
                examples, features, dataloader, device, bert_encoder, batch_size=64)
            # Positional args follow the repo's compute_predictions_logits signature
            # (n_best_size=20, max_answer_length=10, then output paths and flags)
            predictions, stat = compute_predictions_logits(
                examples, features, cq_results,
                20, 10, False, '', '', '', False, False, 0.0, tokenizer, -1e8, '',
            )
            return predictions

        # Serve BERT encoder
        app = Flask(__name__)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        @app.route('/')
        def index():
            return app.send_static_file('index_single.html')

        @app.route('/files/<path:path>')
        def static_files(path):
            return app.send_static_file('files/' + path)

        args.examples_path = os.path.join('static', 'examples_context.txt')

        @app.route('/get_examples', methods=['GET'])
        def get_examples():
            with open(args.examples_path, 'r') as fp:
                examples = [line.strip() for line in fp.readlines()]
            return jsonify(examples)

        @app.route('/single_api', methods=['GET'])
        def single_api():
            t0 = time()
            single_context = request.args['context']
            single_query = request.args['query']
            outs = context_query_to_logit(single_context, single_query)
            t1 = time()
            out = {'ret': outs, 'time': int(1000 * (t1 - t0))}
            return jsonify(out)

        @app.route('/batch_api', methods=['POST'])
        def batch_api():
            t0 = time()
            batch_context = json.loads(request.form['context'])
            batch_query = json.loads(request.form['query'])
            outs = context_query_to_answer(batch_context, batch_query)
            t1 = time()
            out = {'ret': outs, 'time': int(1000 * (t1 - t0))}
            return jsonify(out)

        logger.info(f'Starting BertEncoder server at {self.get_address(bert_port)}')
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(bert_port)
        IOLoop.instance().start()

    def get_address(self, port):
        assert self.base_ip is not None and len(port) > 0
        return self.base_ip + ':' + port
    def embed_query(self, batch_query):
        emb_session = FuturesSession()
        r = emb_session.post(
            self.get_address(self.query_port) + '/query2vec_api',
            data={'query': json.dumps(batch_query)})

        # Return a callable that blocks on the async request when invoked
        def map_():
            result = r.result()
            emb = result.json()
            return emb, result.elapsed.total_seconds() * 1000
        return map_

    def query(self, query):
        params = {'query': query}
        res = requests.get(self.get_address(self.index_port) + '/api', params=params)
        if res.status_code != 200:
            logger.info('Wrong behavior %d' % res.status_code)
        outs = None
        try:
            outs = json.loads(res.text)
        except Exception:
            logger.info(f'no response or error for q {query}')
            logger.info(res.text)
        return outs

    def batch_query(self, batch_query, batch_context=None, max_answer_length=20, top_k=10, nprobe=64):
        post_data = {
            'query': json.dumps(batch_query),
            'context': json.dumps(batch_context) if batch_context is not None else json.dumps(batch_query),
            'max_answer_length': max_answer_length,
            'top_k': top_k,
            'nprobe': nprobe,
        }
        res = requests.post(self.get_address(self.index_port) + '/batch_api', data=post_data)
        if res.status_code != 200:
            logger.info('Wrong behavior %d' % res.status_code)
        outs = None
        try:
            outs = json.loads(res.text)
        except Exception:
            logger.info(f'no response or error for q {batch_query}')
            logger.info(res.text)
        return outs

    def eval_request(self, args):
        # Load dataset
        qids, questions, answers, _ = load_qa_pairs(args.test_path, args)

        # Run batch_query and evaluate
        step = args.eval_batch_size
        predictions = []
        evidences = []
        titles = []
        scores = []
        start_time = None
        num_q = 0
        for q_idx in tqdm(range(0, len(questions), step)):
            if q_idx >= 5 * step:  # exclude warmup batches from the latency measurement
                if start_time is None:
                    start_time = time()
                num_q += len(questions[q_idx:q_idx + step])
            result = self.batch_query(
                questions[q_idx:q_idx + step],
                max_answer_length=args.max_answer_length,
                top_k=args.top_k,
                nprobe=args.nprobe,
            )
            prediction = [[ret['answer'] for ret in out] if len(out) > 0 else ['']
                          for out in result['ret']]
            evidence = [[ret['context'] for ret in out] if len(out) > 0 else ['']
                        for out in result['ret']]
            title = [[ret['title'] for ret in out] if len(out) > 0 else ['']
                     for out in result['ret']]
            score = [[ret['score'] for ret in out] if len(out) > 0 else [-1e10]
                     for out in result['ret']]
            predictions += prediction
            evidences += evidence
            titles += title
            scores += score

        if start_time is not None:
            latency = time() - start_time
            logger.info(f'{latency:.3f} sec for {num_q} questions => {num_q/latency:.1f} Q/Sec')

        eval_fn = evaluate_results if not args.is_kilt else evaluate_results_kilt
        eval_fn(
            predictions, qids, questions, answers, args,
            evidences=evidences, scores=scores, titles=titles,
        )
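# A hedged usage sketch for the DensePhrasesDemo client methods above (not part
# of the original module). It assumes a query-encoder server and an index
# server are already running at the given ports; base_ip and the port values
# are hypothetical placeholders, DATA_DIR must be set in the environment, and
# SimpleNamespace stands in for the parsed argparse options.
def _example_demo_client():
    from types import SimpleNamespace

    demo_args = SimpleNamespace(base_ip='http://localhost', query_port='9010',
                                index_port='9020', truecase_path='truecase.pkl')
    demo = DensePhrasesDemo(demo_args)

    # Single query against /api (uses the server's default hyperparameters)
    print(demo.query('Who won the Nobel Prize in peace'))

    # Batched queries against /batch_api with explicit hyperparameters
    out = demo.batch_query(['When was Obama born', 'Capital of France'],
                           max_answer_length=20, top_k=10, nprobe=64)
    print(out['time'], 'ms')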
truecase = None  # lazily initialized by load_qa_pairs below


def load_qa_pairs(data_path, args, q_idx=None, draft_num_examples=100, shuffle=False):
    q_ids = []
    questions = []
    answers = []
    titles = []
    data = json.load(open(data_path))['data']
    for data_idx, item in enumerate(data):
        if q_idx is not None:
            if data_idx != q_idx:
                continue
        q_id = item['id']
        if 'origin' in item:
            q_id = item['origin'].split('.')[0] + '-' + q_id
        question = item['question']
        # Keep a window of 300 characters around the entity marks
        if '[START_ENT]' in question:
            question = question[
                max(question.index('[START_ENT]') - 300, 0):question.index('[END_ENT]') + 300]
        answer = item['answers']
        title = item.get('titles', [''])
        if len(answer) == 0:
            continue
        q_ids.append(q_id)
        questions.append(question)
        answers.append(answer)
        titles.append(title)
    # Strip trailing question marks
    questions = [query[:-1] if query.endswith('?') else query for query in questions]
    # questions = [query.lower() for query in questions]  # force lower query

    if args.do_lower_case:
        logger.info('Lowercasing queries')
        questions = [query.lower() for query in questions]

    if shuffle:
        logger.info('Shuffling QA pairs')
        qa_pairs = list(zip(q_ids, questions, answers, titles))
        random.shuffle(qa_pairs)
        q_ids, questions, answers, titles = zip(*qa_pairs)

    if args.draft:
        q_ids = q_ids[:draft_num_examples]
        questions = questions[:draft_num_examples]
        answers = answers[:draft_num_examples]
        titles = titles[:draft_num_examples]

    if args.truecase:
        try:
            global truecase
            if truecase is None:
                logger.info('Loading truecaser')
                truecase = TrueCaser(os.path.join(os.environ['DATA_DIR'], args.truecase_path))
            logger.info('Truecasing queries')
            questions = [
                truecase.get_true_case(query) if query == query.lower() else query
                for query in questions
            ]
        except Exception as e:
            logger.warning(e)

    logger.info(f'Loading {len(questions)} questions from {data_path}')
    logger.info(f'Sample Q ({q_ids[0]}): {questions[0]}, A: {answers[0]}, Title: {titles[0]}')
    return q_ids, questions, answers, titles
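# A hedged sketch of the extra fields this second load_qa_pairs variant
# handles: an 'origin' file prefix folded into the id, optional 'titles', and
# the character window kept around [START_ENT]...[END_ENT] marks. All concrete
# values below are hypothetical, and the variant defined directly above is
# assumed to be the one in scope.
def _example_load_qa_pairs_kilt():
    import tempfile
    from types import SimpleNamespace

    record = {
        'id': '42',
        'origin': 'triviaqa.jsonl',  # id becomes 'triviaqa-42'
        'question': 'Where was [START_ENT] Marie Curie [END_ENT] born?',
        'answers': ['Warsaw'],
        'titles': ['Marie Curie'],
    }
    with tempfile.NamedTemporaryFile('w', suffix='.json', delete=False) as fp:
        json.dump({'data': [record]}, fp)
        path = fp.name

    demo_args = SimpleNamespace(do_lower_case=False, draft=False, truecase=False,
                                truecase_path='truecase.pkl')
    q_ids, questions, answers, titles = load_qa_pairs(path, demo_args)
    assert q_ids[0] == 'triviaqa-42'
    assert titles[0] == ['Marie Curie']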