Пример #1
0
    def csv_generator(self) -> Generator:
        """Yield ``(cid, passage)`` pairs read from the ``self.file`` csv.

        The file is looked up relative to the current working directory
        first and, failing that, inside the package ``resources`` directory.
        A progress bar sized by ``count_lines`` is advanced after every row
        the consumer has taken.

        :raises SystemExit: with exit status 1 when the file exists in
            neither location.
        :raises ValueError: from tuple unpacking, when a row does not have
            exactly two columns.
        """

        cwd_path = Path().joinpath(self.file).absolute()
        pkg_path = PKG_PATH.joinpath('resources').joinpath(self.file)

        if cwd_path.exists():
            path = cwd_path
        elif pkg_path.exists():
            path = pkg_path
        else:
            self.logger.error('Could not find %s or %s', pkg_path, cwd_path)
            # a bare ``raise SystemExit`` carries no code and therefore exits
            # with status 0, reporting this failure as success
            raise SystemExit(1)

        self.logger.info('Estimating completion size...')
        num_lines = count_lines(path)
        # newline='' hands newline handling to the csv module, which its
        # documentation asks for so quoted fields with line breaks survive
        with path.open(newline='') as file:
            with tqdm(total=num_lines, desc=path.name) as pbar:
                for cid, passage in csv.reader(file, delimiter=self.delim):
                    yield cid, passage
                    pbar.update()
Пример #2
0
    def csv_generator(self) -> Generator:
        """Yield ``(cid, columns)`` for every row of the ``self.file`` csv.

        The file is looked up relative to the current working directory
        first and, failing that, inside the package ``resources`` directory.

        ``columns`` is a dict of ``{<column name>: <column value>}``. When
        ``self.id_col`` is truthy the first column is removed from that dict
        and its value is yielded as ``cid``; otherwise ``cid`` is ``None``.

        :raises SystemExit: with exit status 1 when the file exists in
            neither location.
        """

        cwd_path = Path().joinpath(self.file).absolute()
        pkg_path = PKG_PATH.joinpath('resources').joinpath(self.file)

        if cwd_path.exists():
            path = cwd_path
        elif pkg_path.exists():
            path = pkg_path
        else:
            self.logger.error('Could not find %s or %s', pkg_path, cwd_path)
            # a bare ``raise SystemExit`` carries no code and therefore exits
            # with status 0, reporting this failure as success
            raise SystemExit(1)

        self.logger.info('Estimating completion size...')
        num_lines = count_lines(path)
        # newline='' hands newline handling to the csv module, which its
        # documentation asks for so quoted fields with line breaks survive
        with path.open(newline='') as file:
            with tqdm(total=num_lines, desc=path.name) as pbar:
                for line in csv.DictReader(file, delimiter=self.delim):
                    cid = None

                    if self.id_col:
                        # DictReader rows are plain dicts since Python 3.8
                        # and dict.popitem() accepts no ``last`` argument, so
                        # take the first column by key; this also works for
                        # the OrderedDict rows of older versions
                        first_column = next(iter(line))
                        cid = line.pop(first_column)

                    yield cid, dict(line)

                    pbar.update()
Пример #3
0
 def __init__(self,
              lr: float = 10e-3,
              model_dir: str = 'bert-base-uncased-msmarco',
              data_dir: Path = PKG_PATH.joinpath('.cache'),
              max_seq_len: int = 128,
              batch_size: int = 4,
              **_):
     """Store the hyperparameters and work out where the model lives.

     :param lr: stored as ``self.lr``.
     :param model_dir: an existing filesystem path, or otherwise a name
         that is looked up underneath ``data_dir``.
     :param data_dir: stored as ``self.data_dir``.
     :param max_seq_len: stored as ``self.max_seq_len``.
     :param batch_size: stored as ``self.batch_size``.
     :param _: any further keyword arguments are accepted and ignored.
     """
     super().__init__()
     self.lr, self.max_seq_len, self.batch_size = lr, max_seq_len, batch_size
     self.data_dir = data_dir

     # a model_dir that exists on disk is used as given; anything else is
     # resolved to an absolute path inside data_dir
     self.model_dir = (Path(model_dir) if os.path.exists(model_dir)
                       else data_dir.joinpath(model_dir).absolute())

     # the logger is named after the raw argument, not the resolved path
     self.logger = set_logger(model_dir)
Пример #4
0
def set_parser() -> ArgumentParser:
    """Create the nboost cli parser and register its default arguments.

    :return: an ``ArgumentParser`` carrying every option listed below.
    """
    parser = ArgumentParser(description=DESCRIPTION)

    # the only boolean switch; it is registered first
    parser.add_argument('--verbose', action='store_true', default=False, help=VERBOSE)

    # (flag, type, default, help) in registration order
    typed_options = (
        ('--host', str, '0.0.0.0', HOST),
        ('--port', int, 8000, PORT),
        ('--uhost', str, '0.0.0.0', UHOST),
        ('--uport', int, 9200, UPORT),
        ('--lr', float, 10e-3, LR),
        ('--model_dir', str, 'bert-base-uncased-msmarco', MODEL_DIR),
        ('--data_dir', Path, PKG_PATH.joinpath('.cache'), DATA_DIR),
        ('--max_seq_len', int, 64, MAX_SEQ_LEN),
        ('--bufsize', int, 2048, BUFSIZE),
        ('--batch_size', int, 4, BATCH_SIZE),
        ('--multiplier', int, 5, MULTIPLIER),
        ('--workers', int, 10, WORKERS),
        # string defaults are run through ``type`` by argparse, so these two
        # class names go through import_class just like user-supplied values
        ('--codex', lambda x: import_class('codex', x), 'ESCodex', PROTOCOL),
        ('--model', lambda x: import_class('model', x), 'BertModel', MODEL),
    )
    for flag, converter, default, help_text in typed_options:
        parser.add_argument(flag, type=converter, default=default, help=help_text)

    return parser
Пример #5
0
    def csv_generator(self) -> Generator:
        """Yield the `--id_col` and `--field_col` from the `--file` csv.

        The file is looked up inside the package ``resources`` directory
        first and, failing that, relative to the current working directory.
        Each row yields ``(line[self.id_col], line[self.field_col])``; the
        progress bar is advanced before the row is handed to the consumer.

        :raises SystemExit: with exit status 1 when the file exists in
            neither location.
        """
        pkg_path = PKG_PATH.joinpath('resources').joinpath(self.file)
        cwd_path = Path().joinpath(self.file).absolute()

        if pkg_path.exists():
            path = pkg_path
        elif cwd_path.exists():
            path = cwd_path
        else:
            self.logger.error('Could not find %s or %s', pkg_path, cwd_path)
            # a bare ``raise SystemExit`` carries no code and therefore exits
            # with status 0, reporting this failure as success
            raise SystemExit(1)

        self.logger.info('Estimating completion size...')
        num_lines = count_lines(path)
        # newline='' hands newline handling to the csv module, which its
        # documentation asks for so quoted fields with line breaks survive
        with path.open(newline='') as file:
            with tqdm(total=num_lines, desc=path.name) as pbar:
                for line in csv.reader(file, delimiter=self.delim):
                    pbar.update()
                    yield line[self.id_col], line[self.field_col]
Пример #6
0
    def __init__(self,
                 data_dir: Path = PKG_PATH.joinpath('.cache'),
                 model_dir: str = 'pt-bert-base-uncased-msmarco',
                 qa_model_dir: str = 'distilbert-base-uncased-distilled-squad',
                 qa_model: str = str(),
                 model: str = str(),
                 qa: bool = False,
                 config: str = 'elasticsearch',
                 verbose: bool = False,
                 **kwargs):
        """Resolve the model(s) and assemble the global ``self.config``.

        :param data_dir: directory the model directories live in; created
            if it does not exist.
        :param model_dir: name of the model directory inside ``data_dir``.
        :param qa_model_dir: name of the qa model directory inside
            ``data_dir``; only resolved when ``qa`` is set.
        :param qa_model: passed to ``resolve_model`` for the qa model.
        :param model: passed to ``resolve_model`` for the main model.
        :param qa: whether to resolve a qa model as well.
        :param config: key into ``CONFIG_MAP`` whose entries are merged
            into ``self.config``.
        :param verbose: forwarded to the main model and to the logger.
        :param kwargs: forwarded to the parent class and to both
            ``resolve_model`` calls, and merged last into ``self.config``.
        """
        super().__init__(**kwargs)
        self.qa = qa
        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)

        self.model_dir = data_dir.joinpath(model_dir).absolute()
        self.model = self.resolve_model(
            self.model_dir, model, verbose=verbose, **kwargs)
        self.logger = set_logger(
            self.model.__class__.__name__, verbose=verbose)

        if qa:
            self.qa_model_dir = data_dir.joinpath(qa_model_dir).absolute()
            self.qa_model = self.resolve_model(
                self.qa_model_dir, qa_model, **kwargs)
            qa_model_name = self.qa_model.__class__.__name__
            qa_dir_name = qa_model_dir
        else:
            qa_model_name = qa_dir_name = None

        # these are global parameters that are overridden by the nboost json
        # key; later sources win: base values, then the preset, then kwargs
        settings = {
            'model': self.model.__class__.__name__,
            'model_dir': model_dir,
            'qa_model': qa_model_name,
            'qa_model_dir': qa_dir_name,
            'data_dir': str(data_dir),
        }
        settings.update(CONFIG_MAP[config])
        settings.update(kwargs)
        self.config = settings
Пример #7
0
class MsMarco(Benchmarker):
    """MSMARCO dataset benchmarker.

    Constructing an instance downloads and extracts the dataset archive when
    its directory does not exist yet, indexes the collection into the direct
    Elasticsearch instance when the index is missing or holds fewer
    documents than the collection, and finally loads the qrels and queries
    through ``add_qrel`` / ``add_query``.
    """
    DEFAULT_URL = ('https://msmarco.blob.core.windows.net'
                   '/msmarcoranking/collectionandqueries.tar.gz')
    BASE_DATASET_DIR = PKG_PATH.joinpath('.cache/datasets/ms_marco')

    def __init__(self, args):
        """Prepare the dataset, the index and the in-memory qrels/queries.

        :param args: parsed arguments; ``url`` and ``shards`` are read here,
            ``host``, ``port``, ``uhost`` and ``uport`` via ``self.args``.
        """
        super().__init__(args)
        self.url = args.url or self.DEFAULT_URL
        archive_file = self.url.split('/')[-1]
        archive_name = archive_file.split('.')[0]
        self.dataset_dir = self.BASE_DATASET_DIR.joinpath(archive_name)
        self.tar_gz_path = self.dataset_dir.joinpath(archive_file)
        self.qrels_tsv_path = self.dataset_dir.joinpath('qrels.dev.small.tsv')
        self.queries_tsv_path = self.dataset_dir.joinpath('queries.dev.tsv')
        self.collections_tsv_path = self.dataset_dir.joinpath('collection.tsv')
        self.index = 'ms_marco_' + archive_name

        # DOWNLOAD MSMARCO
        if not self.dataset_dir.exists():
            self.dataset_dir.mkdir(parents=True, exist_ok=True)
            self.logger.info('Dowloading MSMARCO to %s', self.tar_gz_path)
            download_file(self.url, self.tar_gz_path)
            self.logger.info('Extracting MSMARCO')
            extract_tar_gz(self.tar_gz_path, self.dataset_dir)
            self.tar_gz_path.unlink()

        self.proxy_es = Elasticsearch(host=self.args.host,
                                      port=self.args.port,
                                      timeout=REQUEST_TIMEOUT)
        self.direct_es = Elasticsearch(host=self.args.uhost,
                                       port=self.args.uport,
                                       timeout=REQUEST_TIMEOUT)

        with open(self.collections_tsv_path) as collection:
            collection_size = sum(1 for _ in collection)

        # INDEX MSMARCO
        # (re)index when the index is missing or only partially filled
        try:
            indexed = self.direct_es.count(index=self.index)['count']
            needs_indexing = indexed < collection_size
        except elasticsearch.exceptions.NotFoundError:
            needs_indexing = True

        if needs_indexing:
            try:
                self.direct_es.indices.create(index=self.index,
                                              body={
                                                  'settings': {
                                                      'index': {
                                                          'number_of_shards':
                                                          args.shards
                                                      }
                                                  }
                                              })
            except Exception as exc:  # pylint: disable=broad-except
                # creation stays best-effort (the index may already exist),
                # but KeyboardInterrupt/SystemExit are no longer swallowed
                # and the reason is kept in the debug log
                self.logger.debug('Could not create index %s: %r',
                                  self.index, exc)
            self.logger.info('Indexing %s', self.collections_tsv_path)
            es_bulk_index(self.direct_es, self.stream_msmarco_full())

        self.logger.info('Reading %s', self.qrels_tsv_path)
        with self.qrels_tsv_path.open() as file:
            qrels = csv.reader(file, delimiter='\t')
            for qid, _, doc_id, _ in qrels:
                self.add_qrel(qid, doc_id)

        self.logger.info('Reading %s', self.queries_tsv_path)
        with self.queries_tsv_path.open() as file:
            queries = csv.reader(file, delimiter='\t')
            for qid, query in queries:
                self.add_query(qid, query)

    def stream_msmarco_full(self):
        """Yield one bulk-index action per row of the collection tsv.

        Each action is a dict with ``_index``, ``_id`` and a ``_source``
        holding the passage. A progress bar is advanced after each action
        has been consumed.
        """
        self.logger.info('Optimizing streamer...')
        # count inside a ``with`` so the counting handle is closed again
        with self.collections_tsv_path.open() as fh:
            num_lines = sum(1 for _ in fh)
        with self.collections_tsv_path.open() as fh:
            data = csv.reader(fh, delimiter='\t')
            with tqdm(total=num_lines, desc='INDEXING MSMARCO') as pbar:
                for ident, passage in data:
                    body = dict(_index=self.index,
                                _id=ident,
                                _source={'passage': passage})
                    yield body
                    pbar.update()

    def proxied_doc_id_producer(self, query: str):
        """Return the doc ids for ``query`` as served by ``self.proxy_es``."""
        return self.es_doc_id_producer(self.proxy_es, query)

    def direct_doc_id_producer(self, query: str):
        """Return the doc ids for ``query`` as served by ``self.direct_es``."""
        return self.es_doc_id_producer(self.direct_es, query)

    def es_doc_id_producer(self, es: Elasticsearch, query: str):
        """Run a ``match`` query on the passage field and return the hit ids.

        :param es: the Elasticsearch client to search with.
        :param query: the query text.
        :return: list of ``_id`` values, at most ``self.args.topk`` as
            requested through the ``size`` of the search body.
        """
        body = dict(size=self.args.topk,
                    query={"match": {
                        "passage": {
                            "query": query
                        }
                    }})

        res = es.search(index=self.index,
                        body=body,
                        filter_path=['hits.hits._*'])

        doc_ids = [hit['_id'] for hit in res['hits']['hits']]
        return doc_ids
Пример #8
0
    def __init__(
            self,
            host: type(defaults.host) = defaults.host,
            port: type(defaults.port) = defaults.port,
            verbose: type(defaults.verbose) = defaults.verbose,
            data_dir: type(defaults.data_dir) = defaults.data_dir,
            no_rerank: type(defaults.no_rerank) = defaults.no_rerank,
            model: type(defaults.model) = defaults.model,
            model_dir: type(defaults.model_dir) = defaults.model_dir,
            qa: type(defaults.qa) = defaults.qa,
            qa_model: type(defaults.qa_model) = defaults.qa_model,
            qa_model_dir: type(defaults.qa_model_dir) = defaults.qa_model_dir,
            search_route: type(defaults.search_route) = defaults.search_route,
            frontend_route: type(
                defaults.frontend_route) = defaults.frontend_route,
            status_route: type(defaults.status_route) = defaults.status_route,
            debug: type(defaults.debug) = defaults.debug,
            prerank: type(defaults.prerank) = defaults.prerank,
            **cli_args):
        """Assemble the plugin chain and the Flask app that proxies requests.

        Nothing is served here; call ``self.run()`` to start listening.

        :param host: interface handed to ``flask_app.run``.
        :param port: port handed to ``flask_app.run``.
        :param verbose: forwarded to ``set_logger``.
        :param data_dir: forwarded to ``resolve_model``.
        :param no_rerank: when true, no rerank model plugin is loaded.
        :param model: rerank model class, forwarded to ``resolve_model``.
        :param model_dir: rerank model directory for ``resolve_model``.
        :param qa: when true, a qa model plugin is loaded as well.
        :param qa_model: qa model class, forwarded to ``resolve_model``.
        :param qa_model_dir: qa model directory for ``resolve_model``.
        :param search_route: accepted but not referenced in this method.
        :param frontend_route: url prefix the static frontend is served at.
        :param status_route: appended to ``frontend_route`` for the status
            endpoint.
        :param debug: when true, a ``DebugPlugin`` is appended.
        :param prerank: when true, a ``PrerankPlugin`` is put first.
        :param cli_args: forwarded to the plugins and used as the lowest
            priority layer of the per-request arguments.
        """
        self.logger = set_logger(self.__class__.__name__, verbose=verbose)
        BackwardsCompatibility().set()
        db = Database()
        plugins = []  # type: List[Plugin]

        # plugins run in the order they are appended here
        if prerank:
            plugins.append(PrerankPlugin())

        if not no_rerank:
            rerank_model_plugin = resolve_model(
                data_dir=data_dir,
                model_dir=model_dir,
                model_cls=model,
                **cli_args)  # type: RerankModelPlugin

            plugins.append(rerank_model_plugin)

        if qa:
            qa_model_plugin = resolve_model(data_dir=data_dir,
                                            model_dir=qa_model_dir,
                                            model_cls=qa_model,
                                            **cli_args)  # type: QAModelPlugin

            plugins.append(qa_model_plugin)

        if debug:
            debug_plugin = DebugPlugin(**cli_args)
            plugins.append(debug_plugin)

        static_dir = str(PKG_PATH.joinpath('resources/frontend'))
        flask_app = Flask(__name__)

        @flask_app.route(frontend_route, methods=['GET'])
        def frontend_root():
            """Serve the frontend entry page."""
            return send_from_directory(static_dir, 'index.html')

        @flask_app.route(frontend_route + '/<path:path>', methods=['GET'])
        def frontend_path(path):
            """Serve a frontend asset from the static directory."""
            return send_from_directory(static_dir, path)

        @flask_app.route(frontend_route + status_route)
        def status_path():
            """Return the merged plugin configs and database stats as json."""
            configs = {}

            for plugin in plugins:
                configs.update(plugin.configs)

            stats = db.get_stats()
            return jsonify({**configs, **stats})

        # every other path is routed to the 'proxy' endpoint defined below
        flask_app.url_map.add(Rule('/<path:path>', endpoint='proxy'))

        @flask_app.route('/', defaults={'path': ''})
        @flask_app.endpoint('proxy')
        def proxy_through(path):
            """Search request."""
            # parse the client request
            dict_request = flask_request_to_dict_request(
                flask_request)  # takes the json
            db_row = db.new_row()

            # combine command line args and runtime args sent by request;
            # url query args win over the json body, which wins over the cli
            query_args = {}
            for key in list(dict_request['url']['query']):
                if key in defaults.__dict__:
                    query_args[key] = dict_request['url']['query'].pop(key)
            json_args = dict_request['body'].pop('nboost', {})
            args = {**cli_args, **json_args, **query_args}

            # point the request at the upstream server
            request = RequestDelegate(dict_request, **args)
            request.dict['headers'].pop('Host', '')
            request.set_path('url.headers.host',
                             '%s:%s' % (request.uhost, request.uport))
            request.set_path('url.netloc',
                             '%s:%s' % (request.uhost, request.uport))
            request.set_path('url.scheme', 'https' if request.ussl else 'http')

            for plugin in plugins:  # type: Plugin
                plugin.on_request(request, db_row)

            # get response from upstream server
            start_time = perf_counter()
            requests_response = dict_request_to_requests_response(dict_request)
            db_row.response_time = perf_counter() - start_time
            try:
                dict_response = requests_response_to_dict_response(
                    requests_response)
            except JSONDecodeError:
                # not json: pass the raw upstream body through untouched and
                # report it through the logger like the error handler does
                self.logger.error('Upstream response is not valid JSON: %s',
                                  requests_response.content)
                return requests_response.content
            response = ResponseDelegate(dict_response, request)
            response.set_path('body.nboost', {})
            db_row.choices = len(response.choices)

            for plugin in plugins:  # type: Plugin
                plugin.on_response(response, db_row)

            # save stats to sql lite
            db.insert(db_row)

            return dict_response_to_flask_response(dict_response)

        @flask_app.errorhandler(Exception)
        def handle_json_response(error):
            """Log the exception and answer with a json description and 500."""
            self.logger.error('', exc_info=True)
            return jsonify({
                'type': error.__class__.__name__,
                'doc': error.__class__.__doc__,
                'msg': str(error.args)
            }), 500

        # logger.critical() returns None, so the ``or`` always goes on to
        # start the server after the message has been logged
        self.run = lambda: (self.logger.critical('LISTENING %s:%s' % (
            host, port)) or flask_app.run(host=host, port=port))
Пример #9
0
# Module-level default settings, one name per option.
# NOTE(review): these look like the values read as ``defaults.<name>`` by the
# proxy constructor — confirm against the importing modules.
host = '0.0.0.0'
port = 8000
uhost = '0.0.0.0'
uport = 9200
ussl = False
backlog = 100
verbose = False
query_delim = '. '
lr = 10e-3  # NOTE(review): 10e-3 evaluates to 0.01 — confirm 1e-3 was not intended
max_seq_len = 64
bufsize = 2048
batch_size = 4
topn = 50
workers = 10
data_dir = PKG_PATH.joinpath('.cache')  # the '.cache' directory under PKG_PATH
no_rerank = False
model = 'PtTransformersRerankPlugin'
model_dir = 'nboost/pt-tinybert-msmarco'
qa = False
qa_model = 'PtDistilBertQAModelPlugin'
qa_model_dir = 'distilbert-base-uncased-distilled-squad'
qa_threshold = 0
max_query_length = 64
filter_results = False
query_prep = 'lambda query: query'  # kept as source text here, not as a callable
debug = False
db_file = data_dir.joinpath('nboost.db')  # placed inside data_dir
rerank_cids = ListOrCommaDelimitedString()
prerank = False
Пример #10
0
def build():
    """Build dockerfiles: run the BUILD command once per IMAGE_MAP entry.

    Every mapped path is taken relative to PKG_PATH and made absolute before
    it is substituted into the command.
    """
    for image_name, relative_dir in IMAGE_MAP.items():
        context_dir = PKG_PATH.joinpath(relative_dir).absolute()
        command = BUILD.format(image=image_name, path=context_dir)
        execute(command)
Пример #11
0
class Proxy(SocketServer):
    """The proxy object is the core of NBoost.
    The following __init__ contains the main executed functions in nboost.

    :param host: virtual host of the server.
    :param port: server port.
    :param uhost: host of the external search api.
    :param uport: search api port.
    :param multiplier: the factor to multiply the search request by. For
        example, in the case of Elasticsearch if the client requests 10
        results and the multiplier is 6, then the model should receive 60
        results to rank and refine down to 10 (better) results.
    :param field: a tag for the field in the search api result that the
        model should rank results by.
    :param model: uninitialized model class
    :param codex: uninitialized codex class
    """

    # statistical contexts
    STATIC_PATH = PKG_PATH.joinpath('resources/frontend')
    stats = ClassStatistics()

    def __init__(self,
                 model: Type[BaseModel],
                 codex: Type[BaseCodex],
                 uhost: str = '0.0.0.0',
                 uport: int = 9200,
                 bufsize: int = 2048,
                 verbose: bool = False,
                 **kwargs):
        """Store the upstream address and instantiate the model and codex.

        The remaining keyword arguments are passed to the parent class and
        to both the model and the codex constructors.
        """
        super().__init__(**kwargs)
        self.kwargs = kwargs
        self.uaddress = (uhost, uport)
        self.bufsize = bufsize
        self.logger = set_logger(model.__name__, verbose=verbose)

        # pass command line arguments to instantiate each component
        self.model = model(verbose=verbose, **kwargs)
        self.codex = codex(verbose=verbose, **kwargs)

    def on_client_request_url(self, url: URL):
        """Method for screening the url path from the client request

        :raises FrontendRequest: when the path starts with ``/nboost``.
        :raises UnknownRequest: when the path does not match the codex
            ``SEARCH_PATH`` pattern.
        """
        if url.path.startswith('/nboost'):
            raise FrontendRequest

        if not re.match(self.codex.SEARCH_PATH, url.path):
            raise UnknownRequest

    def set_protocol(self, sock: socket.socket) -> HttpProtocol:
        """Construct the protocol with the proxy settings"""
        protocol = HttpProtocol(sock)
        # NOTE(review): this assigns to an attribute named like a setter
        # rather than calling it — confirm against HttpProtocol
        protocol.set_bufsize = self.bufsize
        return protocol

    @stats.time_context
    def proxy_send(self, client_socket, server_socket, buffer: bytearray):
        """Send buffered request to server and receive the rest of the original
        client request"""
        protocol = self.set_protocol(client_socket)
        protocol.set_request_parser()
        server_socket.send(buffer)
        protocol.feed(buffer)
        protocol.add_data_hook(server_socket.send)
        protocol.recv()

    @stats.time_context
    def proxy_recv(self, client_socket, server_socket):
        """Receive the proxied response and pipe to the client"""
        protocol = self.set_protocol(server_socket)
        protocol.set_response_parser()
        protocol.add_data_hook(client_socket.send)
        protocol.recv()

    @stats.time_context
    def client_recv(self, client_socket, request: Request, buffer: bytearray):
        """Receive client request and pipe to buffer in case of exceptions"""
        protocol = self.set_protocol(client_socket)
        protocol.set_request_parser()
        protocol.set_request(request)
        protocol.add_data_hook(buffer.extend)
        protocol.add_url_hook(self.on_client_request_url)
        protocol.recv()

    @staticmethod
    @stats.time_context
    def server_send(server_socket: socket.socket, request: Request):
        """Send magnified request to the server"""
        server_socket.send(request.prepare())

    @stats.time_context
    def server_recv(self, server_socket: socket.socket, response: Response):
        """Receive magnified request from the server"""
        protocol = self.set_protocol(server_socket)
        protocol.set_response_parser()
        protocol.set_response(response)
        protocol.recv()

    @staticmethod
    @stats.time_context
    def client_send(request: Request, response: Response, client_socket):
        """Send the ranked results to the client"""
        raw_response = response.prepare(request)
        client_socket.send(raw_response)

    @stats.time_context
    def model_rank(self, query: bytes, choices: List[Choice]) -> List[int]:
        """Rank the query and choices and return the argsorted indices"""
        return self.model.rank(query, choices)

    @stats.vars_context
    def record_topk_and_choices(self, topk: int = None, choices: list = None):
        """Add topk and choices to the running statistical averages"""

    @stats.vars_context
    def record_mrrs(self, upstream_mrr: float = None, model_mrr: float = None):
        """Add the upstream mrr, model mrr, and search boost to the stats"""
        # an average upstream mrr of zero simply leaves search_boost unset
        with suppress(ZeroDivisionError):
            var = self.stats.record['vars']
            var['search_boost'] = {
                'avg': var['model_mrr']['avg'] / var['upstream_mrr']['avg']
            }

    @stats.time_context
    def server_connect(self, server_socket: socket.socket) -> None:
        """Connect proxied server socket

        :raises UpstreamConnectionError: when the upstream refuses the
            connection; the original error is attached as its cause.
        """
        try:
            server_socket.connect(self.uaddress)
        except ConnectionRefusedError as exc:
            raise UpstreamConnectionError(*self.uaddress) from exc

    @property
    def status(self) -> dict:
        """Return status dictionary in the case of a status request"""
        return {
            'multiplier': self.codex.multiplier,
            **self.stats.record, 'description': 'NBoost, for search ranking.'
        }

    def calculate_mrrs(self, correct_cids: List[str], choices: List[Choice],
                       ranks: List[int]):
        """Calculate the mrr of the upstream server and reranked choices from
        the model. This only occurs if the client specified the "nboost"
        parameter in the request url or body."""
        upstream_mrr = self.calculate_mrr(correct_cids, choices)
        reranked_choices = [choices[rank] for rank in ranks]
        model_mrr = self.calculate_mrr(correct_cids, reranked_choices)
        self.record_mrrs(upstream_mrr=upstream_mrr, model_mrr=model_mrr)

    @staticmethod
    def calculate_mrr(correct_cids: List[str], choices: List[Choice]):
        """Calculate mean reciprocal rank as the first correct result index

        :return: ``1 / position`` (1-based) of the first choice whose cid
            is in ``correct_cids``, or 0 when there is none.
        """
        for i, choice in enumerate(choices, 1):
            if choice.cid in correct_cids:
                return 1 / i
        return 0

    def get_static_file(self, path: str) -> bytes:
        """Construct the static path of the frontend asset requested and return
        the raw file.

        :param path: the request url path; ``/nboost`` maps to
            ``index.html``, anything else has its first ``/nboost/``
            removed.
        :return: the file contents, or the contents of ``404.html`` when
            the asset is not a file inside the static directory.
        """
        if path == '/nboost':
            asset = 'index.html'
        else:
            asset = path.replace('/nboost/', '', 1)

        # ``parents`` is a purely lexical view, so an unresolved path such as
        # '<static>/../secret' still lists the static dir as a parent.
        # resolve() collapses '..' segments (and symlinks) before the check.
        static_root = self.STATIC_PATH.resolve()
        static_path = static_root.joinpath(asset).resolve()

        # for security reasons, make sure there is no access to other dirs;
        # is_file() also keeps directories away from read_bytes()
        if static_root in static_path.parents and static_path.is_file():
            return static_path.read_bytes()
        return self.STATIC_PATH.joinpath('404.html').read_bytes()

    def loop(self, client_socket: socket.socket, address: Tuple[str, str]):
        """Main ioloop for reranking server results to the client. Exceptions
        raised in the http parser must be reraised from __context__ because
        they are caught by the MagicStack implementation"""
        server_socket = self.set_socket()
        buffer = bytearray()
        request = Request()
        response = Response()
        # logging args; ``request`` is formatted lazily, i.e. in whatever
        # state it has when the record is emitted
        log = ('%s:%s %s', *address, request)

        try:
            self.server_connect(server_socket)

            with HttpParserContext():
                # receive and buffer the client request
                self.client_recv(client_socket, request, buffer)
                self.logger.debug(*log)
                field, query = self.codex.parse_query(request)

                # magnify the size of the request to the server
                topk, correct_cids = self.codex.multiply_request(request)
                self.server_send(server_socket, request)

                # make sure server response comes back properly
                self.server_recv(server_socket, response)
                response.unpack()

            if response.status < 300:
                # parse the choices from the magnified response
                choices = self.codex.parse_choices(response, field)
                self.record_topk_and_choices(topk=topk, choices=choices)

                # use the model to rerank the choices
                ranks = self.model_rank(query, choices)[:topk]
                self.codex.reorder_response(request, response, ranks)

                # if the "nboost" param was sent, calculate MRRs
                if correct_cids is not None:
                    self.calculate_mrrs(correct_cids, choices, ranks)

            self.client_send(request, response, client_socket)

        except FrontendRequest:
            self.logger.info(*log)

            if request.url.path == '/nboost/status':
                response.body = json.dumps(self.status, indent=2).encode()
            else:
                response.body = self.get_static_file(request.url.path)
            self.client_send(request, response, client_socket)

        except (UnknownRequest, MissingQuery):
            self.logger.warning(*log)

            # send the initial buffer that was used to check url path
            self.proxy_send(client_socket, server_socket, buffer)

            # stream the client socket to the server socket
            self.proxy_recv(client_socket, server_socket)

        except Exception as exc:
            # for misc errors, send back json error msg
            self.logger.error(repr(exc), exc_info=True)
            response.body = json.dumps(dict(error=repr(exc))).encode()
            response.status = 500
            self.client_send(request, response, client_socket)

        finally:
            client_socket.close()
            server_socket.close()

    def run(self):
        """Same as socket server run() but logs"""
        self.logger.critical('Upstream host is %s:%s', *self.uaddress)
        super().run()

    def close(self):
        """Close the proxy server and model"""
        self.logger.info('Closing model...')
        self.model.close()
        super().close()
Пример #12
0
class Proxy(SocketServer):
    """The proxy object is the core of NBoost.

    It sits between a client and an upstream search api: it magnifies the
    number of results requested, lets a model rerank what the upstream
    returns and sends the best ``topk`` results back to the client.

    :param data_dir: directory holding downloaded/extracted models; created
        if it does not exist.
    :param model_dir: name of the ranking model directory inside data_dir.
    :param qa_model_dir: name of the question answering model directory
        inside data_dir (only used when ``qa`` is true).
    :param qa_model: class name of the qa model, used when qa_model_dir is
        not a key of CLASS_MAP.
    :param model: class name of the ranking model, used when model_dir is
        not a key of CLASS_MAP.
    :param qa: whether to load the qa model and attach answers to responses.
    :param config: key of CONFIG_MAP whose entries are merged into the
        runtime configuration.
    :param verbose: forwarded to set_logger and to the ranking model.
    :param kwargs: forwarded to the parent server, to the models and merged
        into the runtime configuration (e.g. uhost, uport, multiplier,
        bufsize, ... as read from ``self.config`` below).
    """

    stats = ClassStatistics()
    STATIC_PATH = PKG_PATH.joinpath('resources/frontend')

    def __init__(self,
                 data_dir: Path = PKG_PATH.joinpath('.cache'),
                 model_dir: str = 'pt-bert-base-uncased-msmarco',
                 qa_model_dir: str = 'distilbert-base-uncased-distilled-squad',
                 qa_model: str = str(),
                 model: str = str(),
                 qa: bool = False,
                 config: str = 'elasticsearch',
                 verbose: bool = False,
                 **kwargs):
        super().__init__(**kwargs)
        self.qa = qa
        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)

        # NOTE(review): resolve_model() logs through self.logger before it is
        # assigned below, so the parent __init__ presumably provides one --
        # confirm against SocketServer.
        self.model_dir = data_dir.joinpath(model_dir).absolute()
        self.model = self.resolve_model(self.model_dir,
                                        model,
                                        verbose=verbose,
                                        **kwargs)
        self.logger = set_logger(self.model.__class__.__name__,
                                 verbose=verbose)

        if qa:
            self.qa_model_dir = data_dir.joinpath(qa_model_dir).absolute()
            self.qa_model = self.resolve_model(self.qa_model_dir, qa_model,
                                               **kwargs)

        # these are global parameters that are overridden by nboost json key
        self.config = {
            'model': self.model.__class__.__name__,
            'model_dir': model_dir,
            'qa_model': self.qa_model.__class__.__name__ if qa else None,
            'qa_model_dir': qa_model_dir if qa else None,
            'data_dir': str(data_dir),
            **CONFIG_MAP[config],
            **kwargs
        }

    def resolve_model(self, model_dir: Path, cls: str, **kwargs):
        """Dynamically import class from a module in the CLASS_MAP. This is used
        to manage dependencies within nboost. For example, you don't necessarily
        want to import pytorch models everytime you boot up tensorflow...

        :param model_dir: absolute path of the model directory.
        :param cls: class name to fall back on when the directory name is
            not a key of CLASS_MAP.
        :param kwargs: forwarded to the model constructor.
        :return: an instance of the resolved model class.
        :raises ImportError: when no class can be determined for the model.
        """
        name = model_dir.name

        # 1) the model directory is already on disk
        if model_dir.exists():
            self.logger.info('Using model cache from %s', model_dir)

            if name in CLASS_MAP:
                cls = CLASS_MAP[name]
            elif cls not in MODULE_MAP:
                # both placeholders need a value; the original passed a
                # single non-tuple argument, which raised TypeError instead
                raise ImportError('Class "%s" not in %s.' %
                                  (cls, list(MODULE_MAP)))

            model = import_class(MODULE_MAP[cls], cls)
            return model(str(model_dir), **kwargs)

        # 2) a known model that still has to be downloaded and/or extracted
        if name in CLASS_MAP:
            cls = CLASS_MAP[name]
            module = MODULE_MAP[cls]
            url = URL_MAP[name]
            binary_path = self.data_dir.joinpath(Path(url).name)

            if binary_path.exists():
                self.logger.info('Found model cache in %s', binary_path)
            else:
                self.logger.info('Downloading "%s" model.', model_dir)
                download_file(url, binary_path)

            if binary_path.suffixes == ['.tar', '.gz']:
                self.logger.info('Extracting "%s" from %s', model_dir,
                                 binary_path)
                extract_tar_gz(binary_path, self.data_dir)

            model = import_class(module, cls)
            return model(str(model_dir), **kwargs)

        # 3) unknown directory name: rely on the explicitly given class and
        # hand it the bare directory name rather than a local path
        if cls in MODULE_MAP:
            model = import_class(MODULE_MAP[cls], cls)
            return model(name, **kwargs)

        raise ImportError('model_dir %s not found in %s. You must '
                          'set --model class to continue.' %
                          (name, CLASS_MAP.keys()))

    def on_client_request_url(self, url: dict):
        """Method for screening the url path from the client request.

        :raises StatusRequest: the path starts with /nboost/status.
        :raises FrontendRequest: any other path starting with /nboost.
        :raises UnknownRequest: the path does not match capture_path.
        """
        # the status check must come first, its prefix also matches /nboost
        if url['path'].startswith('/nboost/status'):
            raise StatusRequest

        if url['path'].startswith('/nboost'):
            raise FrontendRequest

        if not re.match(self.config['capture_path'], url['path']):
            raise UnknownRequest

    def get_protocol(self) -> HttpProtocol:
        """Return a configured http protocol parser."""
        return HttpProtocol(self.config['bufsize'])

    @stats.time_context
    def frontend_send(self, client_socket, request):
        """Send the static frontend asset named by the request path to the
        client."""
        response = {}
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.set_response(response)
        response['body'] = self.get_static_file(request['url']['path'])
        client_socket.send(prepare_response(response))

    @stats.time_context
    def status_send(self, client_socket, request):
        """Send the status dictionary to the client as indented json."""
        response = {}
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.set_response(response)
        response['body'] = self.status
        response['body'] = dump_json(response['body'], indent=2)
        client_socket.send(prepare_response(response))

    @stats.time_context
    def proxy_send(self, client_socket, server_socket, buffer: bytearray):
        """Send buffered request to server and receive the rest of the original
        client request"""
        protocol = self.get_protocol()
        protocol.set_request_parser()
        protocol.add_data_hook(server_socket.send)
        protocol.feed(buffer)
        protocol.recv(client_socket)

    @stats.time_context
    def proxy_recv(self, client_socket, server_socket):
        """Receive the proxied response and pipe to the client"""
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.add_data_hook(client_socket.send)
        protocol.recv(server_socket)

    @stats.time_context
    def client_recv(self, client_socket, request: dict, buffer: bytearray):
        """Receive client request and pipe to buffer in case of exceptions"""
        protocol = self.get_protocol()
        protocol.set_request_parser()
        protocol.set_request(request)
        # keep the raw bytes so the request can be replayed to the upstream
        protocol.add_data_hook(buffer.extend)
        protocol.add_url_hook(self.on_client_request_url)
        protocol.recv(client_socket)

    @staticmethod
    @stats.time_context
    def server_send(server_socket: socket.socket, request: dict):
        """Send magnified request to the server"""
        request['body'] = dump_json(request['body'])
        request['headers']['content-type'] = 'application/json; charset=UTF-8'
        server_socket.send(prepare_request(request))

    @stats.time_context
    def server_recv(self, server_socket: socket.socket, response: dict):
        """Receive the response to the magnified request from the server"""
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.set_response(response)
        protocol.recv(server_socket)

    @staticmethod
    @stats.time_context
    def client_send(request: dict, response: dict, client_socket):
        """Send the ranked results to the client; the body is indented when
        "pretty" appears in the url query."""
        kwargs = dict(indent=2) if 'pretty' in request['url']['query'] else {}
        response['body'] = dump_json(response['body'], **kwargs)
        client_socket.send(prepare_response(response))

    @stats.time_context
    def error_send(self, client_socket, exc: Exception):
        """Send internal server error to the client."""
        response = {}
        protocol = self.get_protocol()
        protocol.set_response(response)
        response['body'] = dump_json({'error': repr(exc)}, indent=2)
        response['status'] = 500
        client_socket.send(prepare_response(response))

    @stats.time_context
    def model_rank(self, query: str, choices: List[str]) -> List[int]:
        """Rank the query and choices and return the argsorted indices"""
        return self.model.rank(query, choices)

    @stats.vars_context
    def record_topk_and_choices(self, topk: int = None, choices: list = None):
        """Add topk and choices to the running statistical averages"""

    @stats.vars_context
    def record_mrrs(self, upstream_mrr: float = None, model_mrr: float = None):
        """Add the upstream mrr, model mrr, and search boost to the stats"""
        # the boost is undefined while the upstream average is still zero
        with suppress(ZeroDivisionError):
            var = self.stats.record['vars']
            var['search_boost'] = {
                'avg': var['model_mrr']['avg'] / var['upstream_mrr']['avg']
            }

    @stats.time_context
    def server_connect(self, server_socket: socket.socket):
        """Connect proxied server socket.

        :raises UpstreamConnectionError: the upstream refused the connection.
        """
        uaddress = (self.config['uhost'], self.config['uport'])
        try:
            server_socket.connect(uaddress)
        except ConnectionRefusedError as exc:
            raise UpstreamConnectionError(
                'Connect error for %s:%s' % uaddress) from exc

    @property
    def status(self) -> dict:
        """Return status dictionary in the case of a status request"""
        return {
            **self.config,
            **self.stats.record, 'description': 'NBoost, for search ranking.'
        }

    def calculate_mrrs(self, true_cids: List[str], cids: List[str],
                       ranks: List[int]):
        """Calculate the mrr of the upstream server and reranked choices from
        the model. This only occurs if the client specified the "nboost"
        parameter in the request url or body."""
        upstream_mrr = self.calculate_mrr(true_cids, cids)
        reranked_cids = [cids[rank] for rank in ranks]
        model_mrr = self.calculate_mrr(true_cids, reranked_cids)
        self.record_mrrs(upstream_mrr=upstream_mrr, model_mrr=model_mrr)

    @staticmethod
    def calculate_mrr(correct: list, guesses: list):
        """Return the reciprocal of the 1-based position of the first guess
        that is in correct, or 0 when no guess is correct."""
        for i, guess in enumerate(guesses, 1):
            if guess in correct:
                return 1 / i
        return 0

    def get_static_file(self, path: str) -> bytes:
        """Construct the static path of the frontend asset requested and return
        the raw file.

        :param path: url path of the request, "/nboost" or "/nboost/<asset>".
        :return: the bytes of the asset, or of 404.html when the asset is
            missing, is not a regular file or lies outside STATIC_PATH.
        """
        if path == '/nboost':
            asset = 'index.html'
        else:
            asset = path.replace('/nboost/', '', 1)

        root = self.STATIC_PATH.resolve()

        # Resolve before comparing: an unresolved "<root>/../../x" still lists
        # <root> among its parents, which would let ".." segments escape.
        try:
            static_path = root.joinpath(asset).resolve()
        except (OSError, RuntimeError, ValueError):
            # unresolvable names (symlink loops, embedded null bytes, ...)
            # are treated like a missing asset
            static_path = None

        # for security reasons, make sure there is no access to other dirs
        if (static_path is not None and root in static_path.parents
                and static_path.is_file()):
            return static_path.read_bytes()

        return root.joinpath('404.html').read_bytes()

    @staticmethod
    def get_request_paths(request, configs) -> Tuple[str, int, list]:
        """Get the request jsonpaths noted in the configs.

        :return: the joined query, the requested topk (or default_topk) and
            the true cids found in the request.
        :raises MissingQuery: no query could be extracted.
        """
        queries = get_jsonpath(request, configs['query_path'])
        topks = get_jsonpath(request, configs['topk_path'])
        true_cids = get_jsonpath(request, configs['true_cids_path'])

        # coerce request variables from their paths
        topk = int(topks[0]) if topks else configs['default_topk']
        query = configs['delim'].join(queries)

        # check for errors
        if not query:
            raise MissingQuery

        return query, topk, true_cids

    @staticmethod
    def get_response_paths(response, configs) -> Tuple[list, list, list]:
        """Get the response jsonpaths noted in the configs.

        :return: the flattened choices plus the cid and the value extracted
            from every choice.
        :raises InvalidChoices: the choices are not a list, or a cid/value
            could not be extracted for every choice.
        """
        choices = get_jsonpath(response, configs['choices_path'])

        if not isinstance(choices, list):
            raise InvalidChoices('choices were not a list')

        choices = flatten(choices)
        cids = get_jsonpath(choices, '[*].' + configs['cids_path'])
        cvalues = get_jsonpath(choices, '[*].' + configs['cvalues_path'])

        # check for errors
        if not len(choices) == len(cids) == len(cvalues):
            raise InvalidChoices('number of choices, cids, and cvalues differ')

        return choices, cids, cvalues

    def loop(self, client_socket: socket.socket, address: Tuple[str, str]):
        """Main ioloop for reranking server results to the client. Exceptions
        raised in the http parser must be reraised from __context__ because
        they are caught by the MagicStack implementation"""
        buffer = bytearray()
        server_socket = self.set_socket()
        request = {}
        response = {}

        try:
            self.server_connect(server_socket)

            with HttpParserContext():
                # receive and buffer the client request
                self.client_recv(client_socket, request, buffer)
                self.logger.debug('Request (%s:%s): search.', *address)

                # combine runtime configs and preset configs
                configs = {**self.config, **request['body'].get('nboost', {})}

                query, topk, true_cids = self.get_request_paths(
                    request, configs)

                # magnify the size of the request to the server
                new_topk = topk * configs['multiplier']
                set_jsonpath(request, configs['topk_path'], new_topk)

                # send the magnified request to the upstream server
                self.server_send(server_socket, request)
                self.server_recv(server_socket, response)

            if response['status'] < 300:
                choices, cids, cvalues = self.get_response_paths(
                    response, configs)

                self.record_topk_and_choices(topk=topk, choices=choices)

                # use the model to rerank the choices
                ranks = self.model_rank(query, cvalues)[:topk]
                reranked = [choices[rank] for rank in ranks]
                set_jsonpath(response, configs['choices_path'], reranked)

                # if the "nboost" param was sent, calculate MRRs
                # NOTE(review): true_cids comes from get_jsonpath(); if that
                # returns an empty list rather than None when nothing was
                # sent, this branch always runs -- confirm.
                if true_cids is not None:
                    self.calculate_mrrs(true_cids, cids, ranks)

                response['body']['nboost'] = {}

                # ranks holds choice indices ordered best first, so the top
                # passage is cvalues[ranks[0]]; skip qa when nothing ranked
                if self.qa and len(ranks) > 0:
                    answer, offsets, score = self.qa_model.get_answer(
                        query, cvalues[ranks[0]])
                    response['body']['nboost']['qa_model'] = answer
                    response['body']['nboost']['qa_model_offsets'] = offsets
                    response['body']['nboost']['qa_model_score'] = score

            self.client_send(request, response, client_socket)

        except FrontendRequest:
            self.logger.info('Request (%s:%s): frontend request', *address)
            self.frontend_send(client_socket, request)

        except StatusRequest:
            self.logger.info('Request (%s:%s): status request', *address)
            self.status_send(client_socket, request)

        except UnknownRequest:
            # not a search request: replay it untouched to the upstream
            self.logger.info('Request (%s:%s): unknown path "%s"', *address,
                             request['url']['path'])
            self.proxy_send(client_socket, server_socket, buffer)
            self.proxy_recv(client_socket, server_socket)

        except MissingQuery:
            self.logger.warning('Request (%s:%s): missing query', *address)
            self.proxy_send(client_socket, server_socket, buffer)
            self.proxy_recv(client_socket, server_socket)

        except InvalidChoices as exc:
            self.logger.warning('Request (%s:%s): %s', *address, *exc.args)
            self.proxy_send(client_socket, server_socket, buffer)
            self.proxy_recv(client_socket, server_socket)

        except Exception as exc:
            # for misc errors, send back json error msg
            self.logger.error('Request (%s:%s): %s.',
                              *address,
                              exc,
                              exc_info=True)
            self.error_send(client_socket, exc)

        finally:
            client_socket.close()
            server_socket.close()

    def run(self):
        """Same as socket server run() but logs the runtime configuration"""
        self.logger.info(dump_json(self.config, indent=4).decode())
        super().run()

    def close(self):
        """Close the proxy server and model"""
        # NOTE(review): only the ranking model is closed here; whether the
        # qa model needs closing as well cannot be told from this file.
        self.logger.info('Closing model...')
        self.model.close()
        super().close()
Пример #13
0
def set_parser() -> ArgumentParser:
    """Build the nboost command line parser with its default arguments.

    The jsonpath/capture options default to SUPPRESS, so they only show up
    in the parsed namespace when given explicitly on the command line.

    :return: the configured ArgumentParser.
    """

    def boolean(value: str) -> bool:
        """Convert a command line token such as "false" or "1" to a bool.

        ``type=bool`` cannot be used for this: bool('False') is True because
        every non-empty string is truthy.
        """
        token = value.strip().lower()
        if token in ('true', 't', 'yes', 'y', 'on', '1'):
            return True
        if token in ('false', 'f', 'no', 'n', 'off', '0', ''):
            return False
        # argparse turns a ValueError from a type callable into a usage error
        raise ValueError('expected a boolean, got %r' % value)

    parser = ArgumentParser(prog='nboost',
                            description=DESCRIPTION,
                            formatter_class=lambda prog: AdHf(
                                prog, max_help_position=100, width=100))

    # options without a parser level default, in their original order
    suppressed_options = (
        ('--capture_path', CAPTURE_PATH),
        ('--query_path', QUERY_PATH),
        ('--topk_path', TOPK_PATH),
        ('--cvalues_path', CVALUES_PATH),
        ('--cids_path', CIDS_PATH),
        ('--true_cids_path', TRUE_CIDS_PATH),
        ('--choices_path', CHOICES_PATH),
    )
    for flag, help_text in suppressed_options:
        parser.add_argument(flag, type=str, default=SUPPRESS, help=help_text)

    parser.add_argument('--verbose',
                        type=boolean,
                        default=False,
                        help=VERBOSE)
    parser.add_argument('--host', type=str, default='0.0.0.0', help=HOST)
    parser.add_argument('--port', type=int, default=8000, help=PORT)
    parser.add_argument('--uhost', type=str, default='0.0.0.0', help=UHOST)
    parser.add_argument('--uport', type=int, default=9200, help=UPORT)
    parser.add_argument('--delim', type=str, default='. ', help=DELIM)
    parser.add_argument('--lr', type=float, default=10e-3, help=LR)
    parser.add_argument('--max_seq_len',
                        type=int,
                        default=64,
                        help=MAX_SEQ_LEN)
    parser.add_argument('--bufsize', type=int, default=2048, help=BUFSIZE)
    parser.add_argument('--batch_size', type=int, default=4, help=BATCH_SIZE)
    parser.add_argument('--multiplier', type=int, default=5, help=MULTIPLIER)
    parser.add_argument('--workers', type=int, default=10, help=WORKERS)
    parser.add_argument('--data_dir',
                        type=Path,
                        default=PKG_PATH.joinpath('.cache'),
                        help=DATA_DIR)
    parser.add_argument('--config',
                        type=str,
                        default='elasticsearch',
                        choices=CONFIG_MAP.keys(),
                        help=CONFIG)
    parser.add_argument('--model', type=str, default='', help=MODEL)
    parser.add_argument('--model_dir',
                        type=str,
                        default='pt-tinybert-msmarco',
                        help=MODEL_DIR)
    parser.add_argument('--qa', type=boolean, default=False, help=QA)
    parser.add_argument('--qa_model', type=str, default='', help=QA_MODEL)
    parser.add_argument('--qa_model_dir',
                        type=str,
                        default='distilbert-base-uncased-distilled-squad',
                        help=QA_MODEL_DIR)
    parser.add_argument('--filter_results',
                        type=boolean,
                        default=False,
                        help=FILTER_RESULTS)
    return parser