Ejemplo n.º 1
0
 def __init__(self, args):
     """Keep the parsed options, server endpoints, and a query truecaser.

     Args:
         args: parsed options; must provide ``base_ip``, ``query_port``,
             ``index_port`` and ``truecase_path``.
     """
     self.args = args
     # Addresses of the demo's HTTP services.
     self.base_ip = args.base_ip
     self.index_port = args.index_port
     self.query_port = args.query_port
     # NOTE(review): assumes the DATA_DIR environment variable is set.
     truecase_model = os.path.join(os.environ['DATA_DIR'], args.truecase_path)
     self.truecase = TrueCaser(truecase_model)
Ejemplo n.º 2
0
def load_qa_pairs(data_path, args, draft_num_examples=1000, shuffle=False):
    """Load (id, question, answers) triples from a SQuAD-style JSON file.

    Args:
        data_path: path to a JSON file whose top-level 'data' list holds
            items with 'id', 'question' and 'answers' fields.
        args: parsed options; reads args.truecase, args.truecase_path,
            args.do_lower_case and args.draft.
        draft_num_examples: number of examples kept when args.draft is set.
        shuffle: if True, shuffle ids/questions/answers in lockstep.

    Returns:
        (q_ids, questions, answers) — parallel sequences; items with an
        empty answer list are skipped. When shuffle=True the sequences are
        tuples (from zip(*...)) rather than lists, as in the original.
    """
    q_ids = []
    questions = []
    answers = []
    with open(data_path) as f:
        data = json.load(f)['data']
    for item in data:
        answer = item['answers']
        if len(answer) == 0:
            # Unanswerable items carry no supervision signal; skip them.
            continue
        q_ids.append(item['id'])
        questions.append(item['question'])
        answers.append(answer)
    # Strip a trailing question mark; queries are encoded without it.
    questions = [
        query[:-1] if query.endswith('?') else query for query in questions
    ]

    if args.truecase:
        try:
            logger.info('Loading truecaser for queries')
            # NOTE(review): other loaders in this file read DATA_DIR; this
            # one reads DPH_DATA_DIR — confirm which variable is intended.
            truecase = TrueCaser(
                os.path.join(os.environ['DPH_DATA_DIR'], args.truecase_path))
            # Only truecase queries that are entirely lowercase; mixed-case
            # queries are assumed to be cased correctly already.
            questions = [
                truecase.get_true_case(query)
                if query == query.lower() else query for query in questions
            ]
        except Exception:
            # Truecasing is best-effort: log with traceback and keep the
            # raw queries instead of silently printing to stdout.
            logger.exception('Failed to truecase queries')

    if args.do_lower_case:
        logger.info('Lowercasing queries')
        questions = [query.lower() for query in questions]

    if args.draft:
        # Plain list slicing: np.array() on a ragged list of answer lists
        # raises in recent NumPy, and an array round-trip is unnecessary.
        q_ids = list(q_ids)[:draft_num_examples]
        questions = list(questions)[:draft_num_examples]
        answers = list(answers)[:draft_num_examples]

    if shuffle:
        qa_pairs = list(zip(q_ids, questions, answers))
        random.shuffle(qa_pairs)
        q_ids, questions, answers = zip(*qa_pairs)
        logger.info('Shuffling QA pairs')

    logger.info(f'Loading {len(questions)} questions from {data_path}')
    logger.info(f'Sample Q ({q_ids[0]}): {questions[0]}, A: {answers[0]}')
    return q_ids, questions, answers
Ejemplo n.º 3
0
    def __init__(self,
                 load_dir,
                 dump_dir,
                 index_name='start/1048576_flat_OPQ96',
                 device='cuda',
                 verbose=False,
                 **kwargs):
        """Load the query encoder, the phrase index (MIPS) and a truecaser.

        Args:
            load_dir: directory of the pretrained encoder checkpoint.
            dump_dir: directory containing the phrase dump.
            index_name: name of the index under the dump directory.
            device: 'cuda' or 'cpu'.
            verbose: keep library logging at its default level when True.
            **kwargs: extra option overrides merged into ``self.args``.
        """
        print(
            "This could take up to 15 mins depending on the file reading speed of HDD/SSD"
        )

        # Silence noisy library loggers unless the caller wants verbosity.
        if not verbose:
            for noisy in ("densephrases", "transformers"):
                logging.getLogger(noisy).setLevel(logging.WARNING)

        # Parse the default option set, then overlay the caller's choices.
        opts = Options()
        opts.add_model_options()
        opts.add_index_options()
        opts.add_retrieval_options()
        opts.add_data_options()
        self.args = opts.parse()

        self.args.load_dir = load_dir
        self.args.dump_dir = dump_dir
        # NOTE(review): assumes the CACHE_DIR environment variable is set.
        self.args.cache_dir = os.environ['CACHE_DIR']
        self.args.index_name = index_name
        self.args.cuda = (device == 'cuda')
        self.args.__dict__.update(kwargs)

        # Query/phrase encoder.
        self.set_encoder(load_dir, device)

        # Maximum-inner-product-search index over the phrase dump.
        self.mips = load_phrase_index(self.args, ignore_logging=not verbose)

        # Truecaser used to restore capitalization of lowercase queries.
        self.truecase = TrueCaser(
            os.path.join(os.environ['DATA_DIR'], self.args.truecase_path))
        print("Loading DensePhrases Completed!")
Ejemplo n.º 4
0
class DensePhrases(object):
    """End-to-end DensePhrases retriever: query encoder + phrase index.

    Loads a pretrained query encoder and a MIPS phrase index, then serves
    phrase/sentence/paragraph/document retrieval through :meth:`search`.
    """

    def __init__(self,
                 load_dir,
                 dump_dir,
                 index_name='start/1048576_flat_OPQ96',
                 device='cuda',
                 verbose=False,
                 **kwargs):
        """Load the encoder, the phrase index (MIPS) and a truecaser.

        Args:
            load_dir: directory of the pretrained encoder checkpoint.
            dump_dir: directory containing the phrase dump.
            index_name: name of the index under the dump directory.
            device: 'cuda' or 'cpu'.
            verbose: keep library logging at its default level when True.
            **kwargs: extra option overrides merged into ``self.args``.
        """
        print(
            "This could take up to 15 mins depending on the file reading speed of HDD/SSD"
        )

        # Turn off loggers
        if not verbose:
            logging.getLogger("densephrases").setLevel(logging.WARNING)
            logging.getLogger("transformers").setLevel(logging.WARNING)

        # Get default options
        options = Options()
        options.add_model_options()
        options.add_index_options()
        options.add_retrieval_options()
        options.add_data_options()
        self.args = options.parse()

        # Set options
        self.args.load_dir = load_dir
        self.args.dump_dir = dump_dir
        # NOTE(review): assumes the CACHE_DIR environment variable is set.
        self.args.cache_dir = os.environ['CACHE_DIR']
        self.args.index_name = index_name
        self.args.cuda = True if device == 'cuda' else False
        self.args.__dict__.update(kwargs)

        # Load encoder
        self.set_encoder(load_dir, device)

        # Load MIPS
        self.mips = load_phrase_index(self.args, ignore_logging=not verbose)

        # Truecaser used to restore capitalization of lowercase queries.
        self.truecase = TrueCaser(
            os.path.join(os.environ['DATA_DIR'], self.args.truecase_path))
        print("Loading DensePhrases Completed!")

    def search(self,
               query='',
               retrieval_unit='phrase',
               top_k=10,
               truecase=True,
               return_meta=False):
        """Retrieve the top-k results for one query or a batch of queries.

        Args:
            query: a single query string or a list of query strings.
            retrieval_unit: one of 'phrase', 'sentence', 'paragraph',
                'document'; selects the aggregation strategy.
            top_k: number of results returned per query.
            truecase: truecase all-lowercase queries before encoding.
            return_meta: also return the raw result dicts.

        Returns:
            A list of retrieved strings per query (flattened to a single
            list when `query` is a str); with return_meta=True, a tuple
            (retrieved, raw_results).

        Raises:
            NotImplementedError: for an unknown retrieval_unit.
        """
        # If query is str, single query
        single_query = False
        if type(query) == str:
            batch_query = [query]
            single_query = True
        else:
            assert type(query) == list
            batch_query = query

        # Pre-processing: restore capitalization of all-lowercase queries.
        # BUG FIX: the truecased list was previously assigned to `query`
        # and never used — `batch_query` is what gets encoded/searched.
        if truecase:
            batch_query = [
                self.truecase.get_true_case(q) if q == q.lower() else q
                for q in batch_query
            ]

        # Get question vector: concatenate start/end query embeddings.
        outs = self.query2vec(batch_query)
        start = np.concatenate([out[0] for out in outs], 0)
        end = np.concatenate([out[1] for out in outs], 0)
        query_vec = np.concatenate([start, end], 1)

        # Search
        agg_strats = {
            'phrase': 'opt1',
            'sentence': 'opt2',
            'paragraph': 'opt2',
            'document': 'opt3'
        }
        if retrieval_unit not in agg_strats:
            raise NotImplementedError(
                f'"{retrieval_unit}" not supported. Choose one of {agg_strats.keys()}.'
            )
        search_top_k = top_k
        # Over-retrieve for coarser units so top_k survives aggregation.
        # BUG FIX: 'sentece' typo meant the sentence unit never widened.
        if retrieval_unit in ['sentence', 'paragraph', 'document']:
            search_top_k *= 2
        rets = self.mips.search(
            query_vec,
            q_texts=batch_query,
            nprobe=256,
            top_k=search_top_k,
            max_answer_length=10,
            return_idxs=False,
            aggregate=True,
            agg_strat=agg_strats[retrieval_unit],
            return_sent=True if retrieval_unit == 'sentence' else False)

        # Gather results, trimming back down to top_k per query.
        rets = [ret[:top_k] for ret in rets]
        if retrieval_unit == 'phrase':
            retrieved = [[rr['answer'] for rr in ret][:top_k] for ret in rets]
        elif retrieval_unit == 'sentence':
            retrieved = [[rr['context'] for rr in ret][:top_k] for ret in rets]
        elif retrieval_unit == 'paragraph':
            retrieved = [[rr['context'] for rr in ret][:top_k] for ret in rets]
        elif retrieval_unit == 'document':
            # 'title' is a list; keep the first entry per result.
            retrieved = [[rr['title'][0] for rr in ret][:top_k]
                         for ret in rets]
        else:
            raise NotImplementedError()

        # A str query yields flat (un-batched) results.
        if single_query:
            rets = rets[0]
            retrieved = retrieved[0]

        if return_meta:
            return retrieved, rets
        else:
            return retrieved

    def set_encoder(self, load_dir, device='cuda'):
        """(Re)load the query encoder from `load_dir` onto `device`."""
        self.args.load_dir = load_dir
        self.model, self.tokenizer, self.config = load_encoder(
            device, self.args)
        # Batched query -> (start, end) embedding function.
        self.query2vec = get_query2vec(query_encoder=self.model,
                                       tokenizer=self.tokenizer,
                                       args=self.args,
                                       batch_size=64)

    def evaluate(self, test_path, **kwargs):
        """Evaluate retrieval on the QA file at `test_path`.

        kwargs override fields of a deep copy of self.args, so the
        instance's own options remain untouched.
        """
        from eval_phrase_retrieval import evaluate as evaluate_fn

        # Set new arguments
        new_args = copy.deepcopy(self.args)
        new_args.test_path = test_path
        new_args.truecase = True
        new_args.__dict__.update(kwargs)

        # Run with new_arg
        evaluate_fn(new_args, self.mips, self.model, self.tokenizer)
Ejemplo n.º 5
0
class DensePhrasesDemo(object):
    """Client/server harness for the DensePhrases demo.

    Hosts up to three blocking HTTP services (query encoder, phrase index,
    BERT cross-encoder) on Tornado-wrapped Flask apps, and provides client
    helpers (`query`, `batch_query`, `embed_query`, `eval_request`) that
    talk to those services over HTTP.
    """

    def __init__(self, args):
        # Parsed command-line options; also supplies the server endpoints.
        self.args = args
        self.base_ip = args.base_ip
        self.query_port = args.query_port
        self.index_port = args.index_port
        # Truecaser used to restore capitalization of lowercase queries.
        # NOTE(review): assumes the DATA_DIR environment variable is set.
        self.truecase = TrueCaser(
            os.path.join(os.environ['DATA_DIR'], args.truecase_path))

    def serve_query_encoder(self,
                            query_port,
                            args,
                            inmemory=False,
                            batch_size=64,
                            query_encoder=None,
                            tokenizer=None):
        """Serve a query->vector endpoint on `query_port` (blocks forever).

        POST /query2vec_api with form field 'query' (JSON list of strings)
        returns the encoder outputs as JSON.
        """
        device = 'cuda' if args.cuda else 'cpu'
        # Load an encoder unless the caller passed one in.
        if query_encoder is None:
            query_encoder, tokenizer, _ = load_encoder(device, args)
        query2vec = get_query2vec(query_encoder=query_encoder,
                                  tokenizer=tokenizer,
                                  args=args,
                                  batch_size=batch_size)

        # Serve query encoder
        app = Flask(__name__)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        @app.route('/query2vec_api', methods=['POST'])
        def query2vec_api():
            # Encode a JSON-encoded batch of query strings.
            batch_query = json.loads(request.form['query'])
            start_time = time()
            outs = query2vec(batch_query)
            # logger.info(f'query2vec {time()-start_time:.3f} for {len(batch_query)} queries: {batch_query[0]}')
            return jsonify(outs)

        logger.info(
            f'Starting QueryEncoder server at {self.get_address(query_port)}')
        # Tornado event loop; this call does not return.
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(query_port)
        IOLoop.instance().start()

    def serve_phrase_index(self, index_port, args):
        """Serve the phrase-index search API on `index_port` (blocks forever).

        Routes: / (demo page), /files/<path>, /api (single GET query),
        /batch_api (POST batch), /get_examples.
        """
        args.examples_path = os.path.join('densephrases/demo/static',
                                          args.examples_path)

        # Load mips
        mips = load_phrase_index(args)
        app = Flask(__name__, static_folder='./densephrases/demo/static/')
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        def batch_search(batch_query,
                         max_answer_length=20,
                         top_k=10,
                         nprobe=64,
                         return_idxs=False):
            # Embed the batch remotely via the query-encoder service, then
            # run MIPS locally; returns {'ret': results, 'time': ms}.
            t0 = time()
            outs, _ = self.embed_query(batch_query)()
            start = np.concatenate([out[0] for out in outs], 0)
            end = np.concatenate([out[1] for out in outs], 0)
            query_vec = np.concatenate([start, end], 1)

            rets = mips.search(
                query_vec,
                q_texts=batch_query,
                nprobe=nprobe,
                top_k=top_k,
                max_answer_length=max_answer_length,
                return_idxs=return_idxs,
                aggregate=True,
            )
            # Attach the tokenized query to every result for display.
            for ret_idx, ret in enumerate(rets):
                for rr in ret:
                    rr['query_tokens'] = outs[ret_idx][2]
            t1 = time()
            out = {'ret': rets, 'time': int(1000 * (t1 - t0))}
            return out

        @app.route('/')
        def index():
            return app.send_static_file('index.html')

        @app.route('/files/<path:path>')
        def static_files(path):
            return app.send_static_file('files/' + path)

        # This one uses a default hyperparameters (for Demo)
        @app.route('/api', methods=['GET'])
        def api():
            query = request.args['query']
            # Strip a trailing question mark before encoding.
            query = query[:-1] if query.endswith('?') else query
            if args.truecase:
                # Truecase only when everything after the first character
                # is lowercase (first char may legitimately be capitalized).
                if query[1:].lower() == query[1:]:
                    query = self.truecase.get_true_case(query)
            out = batch_search(
                [query],
                max_answer_length=args.max_answer_length,
                top_k=args.top_k,
                nprobe=args.nprobe,
            )
            # Un-batch the single query's results.
            out['ret'] = out['ret'][0]
            return jsonify(out)

        @app.route('/batch_api', methods=['POST'])
        def batch_api():
            # Hyperparameters come from the POST form for batch requests.
            batch_query = json.loads(request.form['query'])
            max_answer_length = int(request.form['max_answer_length'])
            top_k = int(request.form['top_k'])
            nprobe = int(request.form['nprobe'])
            out = batch_search(
                batch_query,
                max_answer_length=max_answer_length,
                top_k=top_k,
                nprobe=nprobe,
            )
            return jsonify(out)

        @app.route('/get_examples', methods=['GET'])
        def get_examples():
            # One example query per line in the examples file.
            with open(args.examples_path, 'r') as fp:
                examples = [line.strip() for line in fp.readlines()]
            return jsonify(examples)

        if self.query_port is None:
            logger.info(
                'You must set self.query_port for querying. You can use self.update_query_port() later on.'
            )
        logger.info(f'Starting Index server at {self.get_address(index_port)}')
        # Tornado event loop; this call does not return.
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(index_port)
        IOLoop.instance().start()

    def serve_bert_encoder(self, bert_port, args):
        """Serve a BERT cross-encoder QA demo on `bert_port` (blocks forever).

        /single_api returns per-token logits (hex-scaled, for display);
        /batch_api returns extractive answer predictions.
        """
        device = 'cuda' if args.cuda else 'cpu'
        # bert_encoder, tokenizer, _ = load_encoder(device, args) # will be just a bert as query_encoder
        bert_encoder, tokenizer = load_cross_encoder(device, args)
        import binascii

        def float_to_hex(vals):
            # Min-max scale each value to [0, 255] and render as 2-digit
            # hex (used as display intensities in the demo frontend).
            strs = []
            # offset = -40.
            # scale = 5.
            minv = min(vals)
            maxv = max(vals)
            for val in vals:
                strs.append('{0:0{1}X}'.format(
                    int(min((val - minv) / (maxv - minv) * 255, 255)), 2))
            return strs

        # Define query to vector function
        def context_query_to_logit(context, query):
            # Run the cross-encoder on a single (context, query) pair and
            # return per-token start/end logits, hex-encoded for display.
            bert_encoder.eval()

            # Phrase encoding style
            dataloader, examples, features = get_cq_dataloader(
                [context], [query],
                tokenizer,
                args.max_query_length,
                batch_size=64)
            cq_results = get_cq_results(examples,
                                        features,
                                        dataloader,
                                        device,
                                        bert_encoder,
                                        batch_size=64)

            outs = []
            for cq_idx, cq_result in enumerate(cq_results):
                # import pdb; pdb.set_trace()
                # Max over end positions of (start_logit + end_logit).
                all_logits = (
                    np.expand_dims(np.array(cq_result.start_logits), axis=1) +
                    np.expand_dims(np.array(cq_result.end_logits),
                                   axis=0)).max(1).tolist()
                out = {
                    'context':
                    ' '.join(features[cq_idx].tokens[0:]),
                    'title':
                    'dummy',
                    'start_logits':
                    float_to_hex(all_logits[0:len(features[cq_idx].tokens)]),
                    'end_logits':
                    float_to_hex(
                        cq_result.end_logits[0:len(features[cq_idx].tokens)]),
                }
                outs.append(out)

            return outs

        def context_query_to_answer(batch_context, batch_query):
            # Run extractive QA over a batch and return text predictions.
            bert_encoder.eval()

            # Phrase encoding style
            dataloader, examples, features = get_bertqa_dataloader(
                batch_context,
                batch_query,
                tokenizer,
                args.max_query_length,
                batch_size=64)
            cq_results = get_bertqa_results(examples,
                                            features,
                                            dataloader,
                                            device,
                                            bert_encoder,
                                            batch_size=64)
            # NOTE(review): positional arguments — presumably n_best_size=20,
            # max_answer_length=10; verify against the repo's
            # compute_predictions_logits signature.
            predictions, stat = compute_predictions_logits(
                examples,
                features,
                cq_results,
                20,
                10,
                False,
                '',
                '',
                '',
                False,
                False,
                0.0,
                tokenizer,
                -1e8,
                '',
            )

            return predictions

        # Serve query encoder
        app = Flask(__name__)
        app.config['JSONIFY_PRETTYPRINT_REGULAR'] = False
        CORS(app)

        @app.route('/')
        def index():
            return app.send_static_file('index_single.html')

        @app.route('/files/<path:path>')
        def static_files(path):
            return app.send_static_file('files/' + path)

        args.examples_path = os.path.join('static', 'examples_context.txt')

        @app.route('/get_examples', methods=['GET'])
        def get_examples():
            with open(args.examples_path, 'r') as fp:
                examples = [line.strip() for line in fp.readlines()]
            return jsonify(examples)

        @app.route('/single_api', methods=['GET'])
        def single_api():
            t0 = time()
            single_context = request.args['context']
            single_query = request.args['query']
            # start_time = time()
            outs = context_query_to_logit(single_context, single_query)
            # logger.info(f'single to logit {time()-start_time}')
            t1 = time()
            out = {'ret': outs, 'time': int(1000 * (t1 - t0))}
            return jsonify(out)

        @app.route('/batch_api', methods=['POST'])
        def batch_api():
            t0 = time()
            batch_context = json.loads(request.form['context'])
            batch_query = json.loads(request.form['query'])
            # start_time = time()
            outs = context_query_to_answer(batch_context, batch_query)
            # logger.info(f'single to logit {time()-start_time}')
            t1 = time()
            out = {'ret': outs, 'time': int(1000 * (t1 - t0))}
            return jsonify(out)

        logger.info(
            f'Starting BertEncoder server at {self.get_address(bert_port)}')
        # Tornado event loop; this call does not return.
        http_server = HTTPServer(WSGIContainer(app))
        http_server.listen(bert_port)
        IOLoop.instance().start()

    def get_address(self, port):
        """Return 'base_ip:port'; both parts must be non-empty strings."""
        assert self.base_ip is not None and len(port) > 0
        return self.base_ip + ':' + port

    def embed_query(self, batch_query):
        """POST the batch to the query-encoder service asynchronously.

        Returns a zero-argument function that blocks on the response and
        returns (embeddings, elapsed_ms).
        """
        emb_session = FuturesSession()
        r = emb_session.post(self.get_address(self.query_port) +
                             '/query2vec_api',
                             data={'query': json.dumps(batch_query)})

        def map_():
            result = r.result()
            emb = result.json()
            return emb, result.elapsed.total_seconds() * 1000

        return map_

    def query(self, query):
        """GET a single-query search from the index server; return its JSON."""
        params = {'query': query}
        res = requests.get(self.get_address(self.index_port) + '/api',
                           params=params)
        if res.status_code != 200:
            logger.info('Wrong behavior %d' % res.status_code)
        try:
            outs = json.loads(res.text)
        # NOTE(review): if parsing fails, `outs` is unbound and the return
        # below raises NameError — consider re-raising or returning None.
        except Exception as e:
            logger.info(f'no response or error for q {query}')
            logger.info(res.text)
        return outs

    def batch_query(self,
                    batch_query,
                    batch_context=None,
                    max_answer_length=20,
                    top_k=10,
                    nprobe=64):
        """POST a batch search to the index server; return its JSON."""
        post_data = {
            'query':
            json.dumps(batch_query),
            # Fall back to the queries themselves when no context is given.
            'context':
            json.dumps(batch_context)
            if batch_context is not None else json.dumps(batch_query),
            'max_answer_length':
            max_answer_length,
            'top_k':
            top_k,
            'nprobe':
            nprobe,
        }
        res = requests.post(self.get_address(self.index_port) + '/batch_api',
                            data=post_data)
        if res.status_code != 200:
            logger.info('Wrong behavior %d' % res.status_code)
        try:
            outs = json.loads(res.text)
        # NOTE(review): same unbound-`outs` hazard as in query() above.
        except Exception as e:
            logger.info(f'no response or error for q {batch_query}')
            logger.info(res.text)
        return outs

    def eval_request(self, args):
        """Run batched retrieval over a QA file and evaluate the results."""
        # Load dataset
        qids, questions, answers, _ = load_qa_pairs(args.test_path, args)

        # Run batch_query and evaluate
        step = args.eval_batch_size
        predictions = []
        evidences = []
        titles = []
        scores = []
        all_tokens = []
        start_time = None
        num_q = 0
        for q_idx in tqdm(range(0, len(questions), step)):
            # Start timing only after 5 warmup batches.
            # NOTE(review): start_time stays None when there are fewer than
            # 5*step questions, so the latency math below would raise.
            if q_idx >= 5 * step:  # exclude warmup
                if start_time is None:
                    start_time = time()
                num_q += len(questions[q_idx:q_idx + step])
            result = self.batch_query(
                questions[q_idx:q_idx + step],
                max_answer_length=args.max_answer_length,
                top_k=args.top_k,
                nprobe=args.nprobe,
            )
            # Empty result lists are padded so downstream zips stay aligned.
            prediction = [[ret['answer']
                           for ret in out] if len(out) > 0 else ['']
                          for out in result['ret']]
            evidence = [[ret['context']
                         for ret in out] if len(out) > 0 else ['']
                        for out in result['ret']]
            title = [[ret['title'] for ret in out] if len(out) > 0 else ['']
                     for out in result['ret']]
            score = [[ret['score']
                      for ret in out] if len(out) > 0 else [-1e10]
                     for out in result['ret']]
            # NOTE(review): q_tokens is computed but never appended to
            # all_tokens; both look vestigial.
            q_tokens = [
                out[0]['query_tokens'] if len(out) > 0 else ''
                for out in result['ret']
            ]
            predictions += prediction
            evidences += evidence
            titles += title
            scores += score
        # NOTE(review): `latency` is never used afterwards.
        latency = time() - start_time
        logger.info(
            f'{time()-start_time:.3f} sec for {num_q} questions => {num_q/(time()-start_time):.1f} Q/Sec'
        )
        # KILT datasets use a dedicated evaluation routine.
        eval_fn = evaluate_results if not args.is_kilt else evaluate_results_kilt
        eval_fn(
            predictions,
            qids,
            questions,
            answers,
            args,
            evidences=evidences,
            scores=scores,
            titles=titles,
        )
Ejemplo n.º 6
0
def load_qa_pairs(data_path,
                  args,
                  q_idx=None,
                  draft_num_examples=100,
                  shuffle=False):
    """Load (id, question, answers, titles) tuples from a JSON QA file.

    Args:
        data_path: path to a JSON file whose top-level 'data' list holds
            items with 'id', 'question', 'answers' and optional 'titles'
            and 'origin' fields.
        args: parsed options; reads args.do_lower_case, args.draft,
            args.truecase and args.truecase_path.
        q_idx: if given, keep only the item at this position.
        draft_num_examples: number of examples kept when args.draft is set.
        shuffle: if True, shuffle all four sequences in lockstep.

    Returns:
        (q_ids, questions, answers, titles) — parallel sequences; items
        with an empty answer list are skipped.
    """
    q_ids = []
    questions = []
    answers = []
    titles = []
    with open(data_path) as f:
        data = json.load(f)['data']
    for data_idx, item in enumerate(data):
        # Optionally restrict loading to a single example by position.
        if q_idx is not None and data_idx != q_idx:
            continue
        q_id = item['id']
        if 'origin' in item:
            # Prefix the id with the originating dataset name.
            q_id = item['origin'].split('.')[0] + '-' + q_id
        question = item['question']
        if '[START_ENT]' in question:
            # Keep a ~300-character window around the marked entity span.
            question = question[max(question.index('[START_ENT]') -
                                    300, 0):question.index('[END_ENT]') + 300]
        answer = item['answers']
        title = item.get('titles', [''])
        if len(answer) == 0:
            # Unanswerable items carry no supervision signal; skip them.
            continue
        q_ids.append(q_id)
        questions.append(question)
        answers.append(answer)
        titles.append(title)
    # Strip a trailing question mark; queries are encoded without it.
    questions = [
        query[:-1] if query.endswith('?') else query for query in questions
    ]

    if args.do_lower_case:
        logger.info('Lowercasing queries')
        questions = [query.lower() for query in questions]

    if shuffle:
        qa_pairs = list(zip(q_ids, questions, answers, titles))
        random.shuffle(qa_pairs)
        q_ids, questions, answers, titles = zip(*qa_pairs)
        logger.info('Shuffling QA pairs')

    if args.draft:
        # Plain list slicing: np.array() on ragged answer/title lists
        # raises in recent NumPy, and an array round-trip is unnecessary.
        q_ids = list(q_ids)[:draft_num_examples]
        questions = list(questions)[:draft_num_examples]
        answers = list(answers)[:draft_num_examples]
        titles = list(titles)[:draft_num_examples]

    if args.truecase:
        try:
            # Reuse the module-level truecaser across calls (lazy init).
            global truecase
            if truecase is None:
                logger.info('loading truecaser')
                truecase = TrueCaser(
                    os.path.join(os.environ['DATA_DIR'], args.truecase_path))
            logger.info('Truecasing queries')
            # Only truecase queries that are entirely lowercase.
            questions = [
                truecase.get_true_case(query)
                if query == query.lower() else query for query in questions
            ]
        except Exception:
            # Truecasing is best-effort: log with traceback and keep the
            # raw queries instead of silently printing to stdout.
            logger.exception('Failed to truecase queries')

    logger.info(f'Loading {len(questions)} questions from {data_path}')
    logger.info(
        f'Sample Q ({q_ids[0]}): {questions[0]}, A: {answers[0]}, Title: {titles[0]}'
    )
    return q_ids, questions, answers, titles