def csv_generator(self) -> Generator:
    """Yield ``(id, passage)`` pairs from the two-column csv ``self.file``.

    The file is looked up in the current working directory first and then in
    the package ``resources`` directory. Rows are split on ``self.delim`` and
    a tqdm progress bar is advanced after each row has been consumed.

    :raises SystemExit: with a non-zero status if the file is in neither place.
    """
    cwd_path = Path().joinpath(self.file).absolute()
    pkg_path = PKG_PATH.joinpath('resources').joinpath(self.file)

    if cwd_path.exists():
        path = cwd_path
    elif pkg_path.exists():
        path = pkg_path
    else:
        self.logger.error('Could not find %s or %s', pkg_path, cwd_path)
        # a bare ``raise SystemExit`` would terminate with status 0 (success)
        raise SystemExit(1)

    self.logger.info('Estimating completion size...')
    num_lines = count_lines(path)

    # newline='' lets the csv module do its own newline handling, so quoted
    # fields containing line breaks are parsed correctly
    with path.open(newline='') as file:
        with tqdm(total=num_lines, desc=path.name) as pbar:
            for cid, passage in csv.reader(file, delimiter=self.delim):
                yield cid, passage
                pbar.update()
def csv_generator(self) -> Generator:
    """Yield ``(id, {column name: column value})`` for each row of ``self.file``.

    The file is looked up in the current working directory first and then in
    the package ``resources`` directory. When ``self.id_col`` is truthy the
    value of the first column is removed from the row and yielded as the id;
    otherwise the id is ``None`` and the row is yielded whole.

    :raises SystemExit: with a non-zero status if the file is in neither place.
    """
    cwd_path = Path().joinpath(self.file).absolute()
    pkg_path = PKG_PATH.joinpath('resources').joinpath(self.file)

    if cwd_path.exists():
        path = cwd_path
    elif pkg_path.exists():
        path = pkg_path
    else:
        self.logger.error('Could not find %s or %s', pkg_path, cwd_path)
        # a bare ``raise SystemExit`` would terminate with status 0 (success)
        raise SystemExit(1)

    self.logger.info('Estimating completion size...')
    num_lines = count_lines(path)

    # newline='' lets the csv module do its own newline handling
    with path.open(newline='') as file:
        with tqdm(total=num_lines, desc=path.name) as pbar:
            for line in csv.DictReader(file, delimiter=self.delim):
                cid = None
                if self.id_col:
                    # DictReader rows are plain dicts on Python >= 3.8, and
                    # dict.popitem() has no ``last`` keyword; pop the first
                    # key explicitly so this works for dict and OrderedDict
                    first_column = next(iter(line))
                    cid = line.pop(first_column)
                yield cid, dict(line)
                pbar.update()
def __init__(self, lr: float = 10e-3,
             model_dir: str = 'bert-base-uncased-msmarco',
             data_dir: Path = PKG_PATH.joinpath('.cache'),
             max_seq_len: int = 128,
             batch_size: int = 4, **_):
    """Store the hyperparameters and work out where the model lives.

    ``model_dir`` is used as given when it already exists on disk; otherwise
    it is taken to be a directory name below ``data_dir``. Any extra keyword
    arguments are accepted and ignored.
    """
    super().__init__()
    self.lr = lr
    self.max_seq_len = max_seq_len
    self.batch_size = batch_size
    self.data_dir = data_dir

    # an existing path wins; anything else is resolved inside the data dir
    found_on_disk = os.path.exists(model_dir)
    self.model_dir = (Path(model_dir) if found_on_disk
                      else data_dir.joinpath(model_dir).absolute())

    self.logger = set_logger(model_dir)
def set_parser() -> ArgumentParser:
    """Create the nboost argument parser populated with its default options."""
    cli = ArgumentParser(description=DESCRIPTION)
    add = cli.add_argument

    # logging
    add('--verbose', action='store_true', default=False, help=VERBOSE)

    # proxy address and upstream address
    add('--host', type=str, default='0.0.0.0', help=HOST)
    add('--port', type=int, default=8000, help=PORT)
    add('--uhost', type=str, default='0.0.0.0', help=UHOST)
    add('--uport', type=int, default=9200, help=UPORT)

    # model settings
    add('--lr', type=float, default=10e-3, help=LR)
    add('--model_dir', type=str, default='bert-base-uncased-msmarco',
        help=MODEL_DIR)
    add('--data_dir', type=Path, default=PKG_PATH.joinpath('.cache'),
        help=DATA_DIR)
    add('--max_seq_len', type=int, default=64, help=MAX_SEQ_LEN)

    # server settings
    add('--bufsize', type=int, default=2048, help=BUFSIZE)
    add('--batch_size', type=int, default=4, help=BATCH_SIZE)
    add('--multiplier', type=int, default=5, help=MULTIPLIER)
    add('--workers', type=int, default=10, help=WORKERS)

    # classes are given by name and imported while the arguments are parsed
    add('--codex', type=lambda name: import_class('codex', name),
        default='ESCodex', help=PROTOCOL)
    add('--model', type=lambda name: import_class('model', name),
        default='BertModel', help=MODEL)
    return cli
def csv_generator(self) -> Generator:
    """Stream the ``--id_col`` and ``--field_col`` values of the ``--file`` csv.

    The package ``resources`` directory is searched before the current
    working directory. The process exits if the file is in neither.
    """
    pkg_path = PKG_PATH.joinpath('resources').joinpath(self.file)
    cwd_path = Path().joinpath(self.file).absolute()

    # first existing candidate wins, package copy takes precedence
    for candidate in (pkg_path, cwd_path):
        if candidate.exists():
            path = candidate
            break
    else:
        self.logger.error('Could not find %s or %s', pkg_path, cwd_path)
        raise SystemExit

    self.logger.info('Estimating completion size...')
    total_rows = count_lines(path)

    with path.open() as handle, tqdm(total=total_rows,
                                     desc=path.name) as progress:
        rows = csv.reader(handle, delimiter=self.delim)
        for row in rows:
            # the bar is advanced before the row is handed to the consumer
            progress.update()
            yield row[self.id_col], row[self.field_col]
def __init__(self, data_dir: Path = PKG_PATH.joinpath('.cache'),
             model_dir: str = 'pt-bert-base-uncased-msmarco',
             qa_model_dir: str = 'distilbert-base-uncased-distilled-squad',
             qa_model: str = str(), model: str = str(),
             qa: bool = False, config: str = 'elasticsearch',
             verbose: bool = False, **kwargs):
    """Resolve the ranking model (and optionally the QA model) and assemble
    the global configuration dictionary.

    Remaining keyword arguments are passed to the parent initializer and to
    the resolved models, and are merged last into ``self.config``.
    """
    super().__init__(**kwargs)
    self.qa = qa
    self.data_dir = data_dir
    self.data_dir.mkdir(parents=True, exist_ok=True)

    # ranking model
    self.model_dir = data_dir.joinpath(model_dir).absolute()
    self.model = self.resolve_model(
        self.model_dir, model, verbose=verbose, **kwargs)
    model_name = self.model.__class__.__name__
    self.logger = set_logger(model_name, verbose=verbose)

    # optional question answering model
    if qa:
        self.qa_model_dir = data_dir.joinpath(qa_model_dir).absolute()
        self.qa_model = self.resolve_model(
            self.qa_model_dir, qa_model, **kwargs)
        qa_name = self.qa_model.__class__.__name__
        qa_dir = qa_model_dir
    else:
        qa_name = None
        qa_dir = None

    # these are global parameters that are overrided by nboost json key
    settings = {
        'model': self.model.__class__.__name__,
        'model_dir': model_dir,
        'qa_model': qa_name,
        'qa_model_dir': qa_dir,
        'data_dir': str(data_dir),
    }
    # preset values first, then explicit keyword arguments take precedence
    settings.update(CONFIG_MAP[config])
    settings.update(kwargs)
    self.config = settings
class MsMarco(Benchmarker):
    """MSMARCO dataset benchmarker.

    On construction the dataset archive is downloaded and extracted if its
    directory does not exist yet, the passage collection is indexed into the
    upstream Elasticsearch when the index holds fewer documents than the
    collection, and the qrels and queries are loaded through ``add_qrel`` /
    ``add_query``.
    """
    DEFAULT_URL = ('https://msmarco.blob.core.windows.net'
                   '/msmarcoranking/collectionandqueries.tar.gz')
    BASE_DATASET_DIR = PKG_PATH.joinpath('.cache/datasets/ms_marco')

    def __init__(self, args):
        """Prepare the dataset, both Elasticsearch clients and the index.

        :param args: parsed arguments; ``url``, ``shards``, ``host``,
            ``port``, ``uhost`` and ``uport`` are read here.
        """
        super().__init__(args)
        if not args.url:
            self.url = self.DEFAULT_URL
        else:
            self.url = args.url
        archive_file = self.url.split('/')[-1]
        archive_name = archive_file.split('.')[0]
        self.dataset_dir = self.BASE_DATASET_DIR.joinpath(archive_name)
        self.tar_gz_path = self.dataset_dir.joinpath(archive_file)
        self.qrels_tsv_path = self.dataset_dir.joinpath('qrels.dev.small.tsv')
        self.queries_tsv_path = self.dataset_dir.joinpath('queries.dev.tsv')
        self.collections_tsv_path = self.dataset_dir.joinpath('collection.tsv')
        self.index = 'ms_marco_' + archive_name

        # DOWNLOAD MSMARCO
        if not self.dataset_dir.exists():
            self.dataset_dir.mkdir(parents=True, exist_ok=True)
            self.logger.info('Dowloading MSMARCO to %s', self.tar_gz_path)
            download_file(self.url, self.tar_gz_path)
            self.logger.info('Extracting MSMARCO')
            extract_tar_gz(self.tar_gz_path, self.dataset_dir)
            self.tar_gz_path.unlink()

        self.proxy_es = Elasticsearch(host=self.args.host,
                                      port=self.args.port,
                                      timeout=REQUEST_TIMEOUT)
        self.direct_es = Elasticsearch(host=self.args.uhost,
                                       port=self.args.uport,
                                       timeout=REQUEST_TIMEOUT)

        with open(self.collections_tsv_path) as collection:
            collection_size = sum(1 for _ in collection)

        # INDEX MSMARCO
        try:
            # an incomplete index is handled exactly like a missing one
            if self.direct_es.count(
                    index=self.index)['count'] < collection_size:
                raise elasticsearch.exceptions.NotFoundError
        except elasticsearch.exceptions.NotFoundError:
            try:
                self.direct_es.indices.create(index=self.index, body={
                    'settings': {
                        'index': {
                            'number_of_shards': args.shards
                        }
                    }
                })
            except Exception as exc:
                # best effort: creation may fail (e.g. a partially filled
                # index already exists) and bulk indexing is still attempted.
                # Unlike a bare ``except:`` this lets KeyboardInterrupt and
                # SystemExit through.
                self.logger.warning('Could not create index %s: %r',
                                    self.index, exc)
            self.logger.info('Indexing %s', self.collections_tsv_path)
            es_bulk_index(self.direct_es, self.stream_msmarco_full())

        self.logger.info('Reading %s', self.qrels_tsv_path)
        with self.qrels_tsv_path.open() as file:
            qrels = csv.reader(file, delimiter='\t')
            for qid, _, doc_id, _ in qrels:
                self.add_qrel(qid, doc_id)

        self.logger.info('Reading %s', self.queries_tsv_path)
        with self.queries_tsv_path.open() as file:
            queries = csv.reader(file, delimiter='\t')
            for qid, query in queries:
                self.add_query(qid, query)

    def stream_msmarco_full(self):
        """Yield one bulk-index action per row of the collection tsv."""
        self.logger.info('Optimizing streamer...')
        # count inside a context manager so the handle is closed promptly
        with self.collections_tsv_path.open() as fh:
            num_lines = sum(1 for _ in fh)
        with self.collections_tsv_path.open() as fh:
            data = csv.reader(fh, delimiter='\t')
            with tqdm(total=num_lines, desc='INDEXING MSMARCO') as pbar:
                for ident, passage in data:
                    body = dict(_index=self.index,
                                _id=ident,
                                _source={'passage': passage})
                    yield body
                    pbar.update()

    def proxied_doc_id_producer(self, query: str):
        """Return the doc ids found for ``query`` through the proxy client."""
        return self.es_doc_id_producer(self.proxy_es, query)

    def direct_doc_id_producer(self, query: str):
        """Return the doc ids found for ``query`` through the direct client."""
        return self.es_doc_id_producer(self.direct_es, query)

    def es_doc_id_producer(self, es: Elasticsearch, query: str):
        """Run a ``match`` query on the ``passage`` field and return the
        ``_id`` of every hit; ``self.args.topk`` results are requested."""
        body = dict(size=self.args.topk,
                    query={"match": {"passage": {"query": query}}})
        res = es.search(index=self.index, body=body,
                        filter_path=['hits.hits._*'])
        return [hit['_id'] for hit in res['hits']['hits']]
def __init__(
        self,
        host: type(defaults.host) = defaults.host,
        port: type(defaults.port) = defaults.port,
        verbose: type(defaults.verbose) = defaults.verbose,
        data_dir: type(defaults.data_dir) = defaults.data_dir,
        no_rerank: type(defaults.no_rerank) = defaults.no_rerank,
        model: type(defaults.model) = defaults.model,
        model_dir: type(defaults.model_dir) = defaults.model_dir,
        qa: type(defaults.qa) = defaults.qa,
        qa_model: type(defaults.qa_model) = defaults.qa_model,
        qa_model_dir: type(defaults.qa_model_dir) = defaults.qa_model_dir,
        search_route: type(defaults.search_route) = defaults.search_route,
        frontend_route: type(
            defaults.frontend_route) = defaults.frontend_route,
        status_route: type(defaults.status_route) = defaults.status_route,
        debug: type(defaults.debug) = defaults.debug,
        prerank: type(defaults.prerank) = defaults.prerank,
        **cli_args):
    """Assemble the plugin chain and the Flask application of the proxy.

    Every named parameter defaults to the attribute of the same name on
    ``defaults``. Depending on ``prerank``, ``no_rerank``, ``qa`` and
    ``debug`` the matching plugins are appended to the chain in that order.
    Routes are registered for the frontend, the status page and a catch-all
    proxy endpoint; ``self.run`` is set to a callable that logs the listen
    address and starts the Flask app on ``host:port``.

    ``search_route`` is accepted but not read in this function.
    Remaining keyword arguments (``cli_args``) are forwarded to the resolved
    plugins and act as the lowest-priority per-request arguments.
    """
    self.logger = set_logger(self.__class__.__name__, verbose=verbose)
    BackwardsCompatibility().set()
    db = Database()

    # plugins run in insertion order on every request and response
    plugins = []  # type: List[Plugin]

    if prerank:
        preRankPlugin = PrerankPlugin()
        plugins.append(preRankPlugin)

    if not no_rerank:
        rerank_model_plugin = resolve_model(
            data_dir=data_dir,
            model_dir=model_dir,
            model_cls=model, **cli_args)  # type: RerankModelPlugin

        plugins.append(rerank_model_plugin)

    if qa:
        qa_model_plugin = resolve_model(data_dir=data_dir,
                                        model_dir=qa_model_dir,
                                        model_cls=qa_model,
                                        **cli_args)  # type: QAModelPlugin

        plugins.append(qa_model_plugin)

    if debug:
        debug_plugin = DebugPlugin(**cli_args)
        plugins.append(debug_plugin)

    static_dir = str(PKG_PATH.joinpath('resources/frontend'))
    flask_app = Flask(__name__)

    # frontend landing page
    @flask_app.route(frontend_route, methods=['GET'])
    def frontend_root():
        return send_from_directory(static_dir, 'index.html')

    # any other frontend asset below the frontend route
    @flask_app.route(frontend_route + '/<path:path>', methods=['GET'])
    def frontend_path(path):
        return send_from_directory(static_dir, path)

    # status page: merged plugin configs plus database statistics
    @flask_app.route(frontend_route + status_route)
    def status_path():
        configs = {}

        for plugin in plugins:
            configs.update(plugin.configs)

        stats = db.get_stats()
        # on a key clash the stats value wins (it is unpacked last)
        return jsonify({**configs, **stats})

    # every other path is routed to the 'proxy' endpoint defined below
    flask_app.url_map.add(Rule('/<path:path>', endpoint='proxy'))

    @flask_app.route('/', defaults={'path': ''})
    @flask_app.endpoint('proxy')
    def proxy_through(path):
        # parse the client request
        dict_request = flask_request_to_dict_request(
            flask_request)  # takes the json

        """Search request."""
        db_row = db.new_row()

        # combine command line args and runtime args sent by request
        query_args = {}
        # iterate over a copy of the keys because matching keys are popped
        for key in list(dict_request['url']['query']):
            if key in defaults.__dict__:
                query_args[key] = dict_request['url']['query'].pop(key)
        json_args = dict_request['body'].pop('nboost', {})
        # precedence (lowest to highest): cli args, json body, url query
        args = {**cli_args, **json_args, **query_args}

        request = RequestDelegate(dict_request, **args)
        # point the outgoing request at the upstream host instead of us
        request.dict['headers'].pop('Host', '')
        request.set_path('url.headers.host',
                         '%s:%s' % (request.uhost, request.uport))
        request.set_path('url.netloc',
                         '%s:%s' % (request.uhost, request.uport))
        request.set_path('url.scheme', 'https' if request.ussl else 'http')

        for plugin in plugins:  # type: Plugin
            plugin.on_request(request, db_row)

        # get response from upstream server
        start_time = perf_counter()
        requests_response = dict_request_to_requests_response(dict_request)
        db_row.response_time = perf_counter() - start_time

        try:
            dict_response = requests_response_to_dict_response(
                requests_response)
        except JSONDecodeError:
            # non-json upstream body: return it untouched, skipping the
            # response plugins and the database insert below
            print(requests_response.content)
            return requests_response.content

        response = ResponseDelegate(dict_response, request)
        response.set_path('body.nboost', {})
        db_row.choices = len(response.choices)

        for plugin in plugins:  # type: Plugin
            plugin.on_response(response, db_row)

        # save stats to sql lite
        db.insert(db_row)

        return dict_response_to_flask_response(dict_response)

    # any uncaught exception is logged and reported as a json 500 response
    @flask_app.errorhandler(Exception)
    def handle_json_response(error):
        self.logger.error('', exc_info=True)
        return jsonify({
            'type': error.__class__.__name__,
            'doc': error.__class__.__doc__,
            'msg': str(error.args)
        }), 500

    # presumably logger.critical(...) returns None (standard logging
    # behaviour), so the ``or`` always goes on to start the app
    self.run = lambda: (self.logger.critical('LISTENING %s:%s' % (
        host, port)) or flask_app.run(host=host, port=port))
# NOTE(review): these look like the default option values that other code
# reads as ``defaults.<name>``; the enclosing scope is not visible here,
# so confirm before relying on that.

# listen address and upstream address
host = '0.0.0.0'
port = 8000
uhost = '0.0.0.0'
uport = 9200
ussl = False
backlog = 100
verbose = False
query_delim = '. '
lr = 10e-3  # 10e-3 == 0.01
max_seq_len = 64
bufsize = 2048
batch_size = 4
topn = 50
workers = 10
data_dir = PKG_PATH.joinpath('.cache')
no_rerank = False
model = 'PtTransformersRerankPlugin'
model_dir = 'nboost/pt-tinybert-msmarco'
qa = False
qa_model = 'PtDistilBertQAModelPlugin'
qa_model_dir = 'distilbert-base-uncased-distilled-squad'
qa_threshold = 0
max_query_length = 64
filter_results = False
query_prep = 'lambda query: query'  # kept as source text (a str), not a callable
debug = False
db_file = data_dir.joinpath('nboost.db')  # derived from data_dir above
rerank_cids = ListOrCommaDelimitedString()
prerank = False
def build():
    """Run the build command once for every entry of IMAGE_MAP."""
    for image_name, relative_dir in IMAGE_MAP.items():
        # image paths are stored relative to the package root
        context_dir = PKG_PATH.joinpath(relative_dir).absolute()
        command = BUILD.format(image=image_name, path=context_dir)
        execute(command)
class Proxy(SocketServer):
    """The proxy object is the core of NBoost.
    The following __init__ contains the main executed functions in nboost.

    :param host: virtual host of the server.
    :param port: server port.
    :param uhost: host of the external search api.
    :param uport: search api port.
    :param multiplier: the factor to multiply the search request by. For
        example, in the case of Elasticsearch if the client requests 10
        results and the multiplier is 6, then the model should receive 60
        results to rank and refine down to 10 (better) results.
    :param field: a tag for the field in the search api result that the
        model should rank results by.
    :param model: uninitialized model class
    :param codex: uninitialized codex class
    """
    # statistical contexts
    STATIC_PATH = PKG_PATH.joinpath('resources/frontend')
    stats = ClassStatistics()

    def __init__(self, model: Type[BaseModel], codex: Type[BaseCodex],
                 uhost: str = '0.0.0.0', uport: int = 9200,
                 bufsize: int = 2048, verbose: bool = False, **kwargs):
        """Store the upstream address and instantiate the model and codex.

        Remaining keyword arguments go to the parent initializer and to both
        the model and the codex constructors.
        """
        super().__init__(**kwargs)
        self.kwargs = kwargs
        self.uaddress = (uhost, uport)
        self.bufsize = bufsize
        self.logger = set_logger(model.__name__, verbose=verbose)

        # pass command line arguments to instantiate each component
        self.model = model(verbose=verbose, **kwargs)
        self.codex = codex(verbose=verbose, **kwargs)

    def on_client_request_url(self, url: URL):
        """Method for screening the url path from the client request

        :raises FrontendRequest: if the path starts with ``/nboost``.
        :raises UnknownRequest: if the path does not match the codex
            ``SEARCH_PATH`` pattern.
        """
        if url.path.startswith('/nboost'):
            raise FrontendRequest

        if not re.match(self.codex.SEARCH_PATH, url.path):
            raise UnknownRequest

    def set_protocol(self, sock: socket.socket) -> HttpProtocol:
        """Construct the protocol with the proxy settings"""
        protocol = HttpProtocol(sock)
        # NOTE(review): this assigns an attribute named ``set_bufsize``; if
        # HttpProtocol defines ``set_bufsize`` as a method this should be a
        # call instead. Left unchanged, confirm against HttpProtocol.
        protocol.set_bufsize = self.bufsize
        return protocol

    @stats.time_context
    def proxy_send(self, client_socket, server_socket, buffer: bytearray):
        """Send buffered request to server and receive the rest of the
        original client request"""
        protocol = self.set_protocol(client_socket)
        protocol.set_request_parser()
        server_socket.send(buffer)
        protocol.feed(buffer)
        protocol.add_data_hook(server_socket.send)
        protocol.recv()

    @stats.time_context
    def proxy_recv(self, client_socket, server_socket):
        """Receive the proxied response and pipe to the client"""
        protocol = self.set_protocol(server_socket)
        protocol.set_response_parser()
        protocol.add_data_hook(client_socket.send)
        protocol.recv()

    @stats.time_context
    def client_recv(self, client_socket, request: Request, buffer: bytearray):
        """Receive client request and pipe to buffer in case of exceptions"""
        protocol = self.set_protocol(client_socket)
        protocol.set_request_parser()
        protocol.set_request(request)
        protocol.add_data_hook(buffer.extend)
        protocol.add_url_hook(self.on_client_request_url)
        protocol.recv()

    @staticmethod
    @stats.time_context
    def server_send(server_socket: socket.socket, request: Request):
        """Send magnified request to the server"""
        server_socket.send(request.prepare())

    @stats.time_context
    def server_recv(self, server_socket: socket.socket, response: Response):
        """Receive the response to the magnified request from the server"""
        protocol = self.set_protocol(server_socket)
        protocol.set_response_parser()
        protocol.set_response(response)
        protocol.recv()

    @staticmethod
    @stats.time_context
    def client_send(request: Request, response: Response, client_socket):
        """Send the ranked results to the client"""
        raw_response = response.prepare(request)
        client_socket.send(raw_response)

    @stats.time_context
    def model_rank(self, query: bytes, choices: List[Choice]) -> List[int]:
        """Rank the query and choices and return the argsorted indices"""
        return self.model.rank(query, choices)

    @stats.vars_context
    def record_topk_and_choices(self, topk: int = None, choices: list = None):
        """Add topk and choices to the running statistical averages"""

    @stats.vars_context
    def record_mrrs(self, upstream_mrr: float = None, model_mrr: float = None):
        """Add the upstream mrr, model mrr, and search boost to the stats"""
        # the boost is simply left out while the upstream average is zero
        with suppress(ZeroDivisionError):
            var = self.stats.record['vars']
            var['search_boost'] = {
                'avg': var['model_mrr']['avg'] / var['upstream_mrr']['avg']
            }

    @stats.time_context
    def server_connect(self, server_socket: socket.socket) -> None:
        """Connect proxied server socket

        :raises UpstreamConnectionError: if the upstream refuses the
            connection.
        """
        try:
            server_socket.connect(self.uaddress)
        except ConnectionRefusedError as exc:
            raise UpstreamConnectionError(*self.uaddress) from exc

    @property
    def status(self) -> dict:
        """Return status dictionary in the case of a status request"""
        return {
            'multiplier': self.codex.multiplier,
            **self.stats.record,
            'description': 'NBoost, for search ranking.'
        }

    def calculate_mrrs(self, correct_cids: List[str], choices: List[Choice],
                       ranks: List[int]):
        """Calculate the mrr of the upstream server and reranked choices
        from the model. This only occurs if the client specified the
        "nboost" parameter in the request url or body."""
        upstream_mrr = self.calculate_mrr(correct_cids, choices)
        reranked_choices = [choices[rank] for rank in ranks]
        model_mrr = self.calculate_mrr(correct_cids, reranked_choices)

        self.record_mrrs(upstream_mrr=upstream_mrr, model_mrr=model_mrr)

    @staticmethod
    def calculate_mrr(correct_cids: List[str], choices: List[Choice]):
        """Calculate mean reciprocal rank as the first correct result index

        Returns ``1 / position`` (1-based) of the first choice whose ``cid``
        is in ``correct_cids``, or 0 when there is none.
        """
        for i, choice in enumerate(choices, 1):
            if choice.cid in correct_cids:
                return 1 / i
        return 0

    def get_static_file(self, path: str) -> bytes:
        """Construct the static path of the frontend asset requested and
        return the raw file.

        ``/nboost`` maps to ``index.html``; any other path has its first
        ``/nboost/`` removed. ``404.html`` is served when the asset is not a
        regular file located below ``STATIC_PATH``.
        """
        if path == '/nboost':
            asset = 'index.html'
        else:
            asset = path.replace('/nboost/', '', 1)

        # for security reasons, make sure there is no access to other dirs.
        # ``parents`` is purely lexical, so both paths are resolved first;
        # otherwise ``/nboost/../../secret`` would still list STATIC_PATH
        # among its parents and escape the frontend directory.
        static_root = self.STATIC_PATH.resolve()
        static_path = static_root.joinpath(asset).resolve()

        # is_file() rather than exists(): a directory cannot be read as bytes
        if static_root in static_path.parents and static_path.is_file():
            return static_path.read_bytes()
        return self.STATIC_PATH.joinpath('404.html').read_bytes()

    def loop(self, client_socket: socket.socket, address: Tuple[str, str]):
        """Main ioloop for reranking server results to the client. Exceptions
        raised in the http parser must be reraised from __context__ because
        they are caught by the MagicStack implementation"""
        server_socket = self.set_socket()
        buffer = bytearray()
        request = Request()
        response = Response()
        log = ('%s:%s %s', *address, request)

        try:
            self.server_connect(server_socket)

            # NOTE(review): the with-block is assumed to span the whole
            # request/response exchange that goes through the http parser;
            # confirm its extent against the original layout.
            with HttpParserContext():
                # receive and buffer the client request
                self.client_recv(client_socket, request, buffer)
                self.logger.debug(*log)
                field, query = self.codex.parse_query(request)

                # magnify the size of the request to the server
                topk, correct_cids = self.codex.multiply_request(request)
                self.server_send(server_socket, request)

                # make sure server response comes back properly
                self.server_recv(server_socket, response)
                response.unpack()

            if response.status < 300:
                # parse the choices from the magnified response
                choices = self.codex.parse_choices(response, field)
                self.record_topk_and_choices(topk=topk, choices=choices)

                # use the model to rerank the choices
                ranks = self.model_rank(query, choices)[:topk]
                self.codex.reorder_response(request, response, ranks)

                # if the "nboost" param was sent, calculate MRRs
                if correct_cids is not None:
                    self.calculate_mrrs(correct_cids, choices, ranks)

            self.client_send(request, response, client_socket)

        except FrontendRequest:
            self.logger.info(*log)

            if request.url.path == '/nboost/status':
                response.body = json.dumps(self.status, indent=2).encode()
            else:
                response.body = self.get_static_file(request.url.path)

            self.client_send(request, response, client_socket)

        except (UnknownRequest, MissingQuery):
            self.logger.warning(*log)

            # send the initial buffer that was used to check url path
            self.proxy_send(client_socket, server_socket, buffer)

            # stream the client socket to the server socket
            self.proxy_recv(client_socket, server_socket)

        except Exception as exc:
            # for misc errors, send back json error msg
            self.logger.error(repr(exc), exc_info=True)
            response.body = json.dumps(dict(error=repr(exc))).encode()
            response.status = 500
            self.client_send(request, response, client_socket)

        finally:
            client_socket.close()
            server_socket.close()

    def run(self):
        """Same as socket server run() but logs"""
        self.logger.critical('Upstream host is %s:%s', *self.uaddress)
        super().run()

    def close(self):
        """Close the proxy server and model"""
        self.logger.info('Closing model...')
        self.model.close()
        super().close()
class Proxy(SocketServer):
    """The proxy object is the core of NBoost.
    The following __init__ contains the main executed functions in nboost.

    :param host: virtual host of the server.
    :param port: server port.
    :param uhost: host of the external search api.
    :param uport: search api port.
    :param multiplier: the factor to multiply the search request by. For
        example, in the case of Elasticsearch if the client requests 10
        results and the multiplier is 6, then the model should receive 60
        results to rank and refine down to 10 (better) results.
    :param field: a tag for the field in the search api result that the
        model should rank results by.
    :param model: uninitialized model class
    :param codex: uninitialized codex class
    """
    stats = ClassStatistics()
    STATIC_PATH = PKG_PATH.joinpath('resources/frontend')

    def __init__(self, data_dir: Path = PKG_PATH.joinpath('.cache'),
                 model_dir: str = 'pt-bert-base-uncased-msmarco',
                 qa_model_dir: str = 'distilbert-base-uncased-distilled-squad',
                 qa_model: str = str(), model: str = str(),
                 qa: bool = False, config: str = 'elasticsearch',
                 verbose: bool = False, **kwargs):
        """Resolve the ranking model (and optionally the QA model) and build
        the global configuration dictionary.

        Remaining keyword arguments are passed to the parent initializer and
        the resolved models, and are merged last into ``self.config``.
        """
        super().__init__(**kwargs)
        self.qa = qa
        self.data_dir = data_dir
        self.data_dir.mkdir(parents=True, exist_ok=True)
        self.model_dir = data_dir.joinpath(model_dir).absolute()
        # NOTE(review): resolve_model logs through self.logger, which is only
        # assigned on the line after this call; presumably the parent class
        # provides a logger already. Confirm.
        self.model = self.resolve_model(self.model_dir, model,
                                        verbose=verbose, **kwargs)
        self.logger = set_logger(self.model.__class__.__name__,
                                 verbose=verbose)

        if qa:
            self.qa_model_dir = data_dir.joinpath(qa_model_dir).absolute()
            self.qa_model = self.resolve_model(self.qa_model_dir, qa_model,
                                               **kwargs)

        # these are global parameters that are overrided by nboost json key
        self.config = {
            'model': self.model.__class__.__name__,
            'model_dir': model_dir,
            'qa_model': self.qa_model.__class__.__name__ if qa else None,
            'qa_model_dir': qa_model_dir if qa else None,
            'data_dir': str(data_dir),
            **CONFIG_MAP[config],
            **kwargs
        }

    def resolve_model(self, model_dir: Path, cls: str, **kwargs):
        """Dynamically import class from a module in the CLASS_MAP. This is
        used to manage dependencies within nboost. For example, you don't
        necessarily want to import pytorch models everytime you boot up
        tensorflow...

        :param model_dir: directory of the model; its name is looked up in
            CLASS_MAP and URL_MAP.
        :param cls: model class name, used when the directory name is not in
            CLASS_MAP.
        :raises ImportError: if no model class can be determined.
        """
        if model_dir.exists():
            self.logger.info('Using model cache from %s', model_dir)

            if model_dir.name in CLASS_MAP:
                cls = CLASS_MAP[model_dir.name]
            elif cls not in MODULE_MAP:
                # the original format string had two placeholders but was
                # given a single argument, so it raised TypeError instead of
                # the intended ImportError
                raise ImportError('Class "%s" not in %s.'
                                  % (cls, MODULE_MAP.keys()))

            module = MODULE_MAP[cls]
            model = import_class(module, cls)
            return model(str(model_dir), **kwargs)

        if model_dir.name in CLASS_MAP:
            # known model that is not on disk yet: fetch (and unpack) it
            cls = CLASS_MAP[model_dir.name]
            module = MODULE_MAP[cls]
            url = URL_MAP[model_dir.name]
            binary_path = self.data_dir.joinpath(Path(url).name)

            if binary_path.exists():
                self.logger.info('Found model cache in %s', binary_path)
            else:
                self.logger.info('Downloading "%s" model.', model_dir)
                download_file(url, binary_path)

            if binary_path.suffixes == ['.tar', '.gz']:
                self.logger.info('Extracting "%s" from %s', model_dir,
                                 binary_path)
                extract_tar_gz(binary_path, self.data_dir)

            model = import_class(module, cls)
            return model(str(model_dir), **kwargs)

        if cls in MODULE_MAP:
            # unknown directory name: hand the bare name to the given class
            module = MODULE_MAP[cls]
            model = import_class(module, cls)
            return model(model_dir.name, **kwargs)

        raise ImportError('model_dir %s not found in %s. You must '
                          'set --model class to continue.'
                          % (model_dir.name, CLASS_MAP.keys()))

    def on_client_request_url(self, url: dict):
        """Method for screening the url path from the client request

        :raises StatusRequest: if the path starts with ``/nboost/status``.
        :raises FrontendRequest: if the path starts with ``/nboost``.
        :raises UnknownRequest: if the path does not match the configured
            ``capture_path`` pattern.
        """
        if url['path'].startswith('/nboost/status'):
            raise StatusRequest

        if url['path'].startswith('/nboost'):
            raise FrontendRequest

        if not re.match(self.config['capture_path'], url['path']):
            raise UnknownRequest

    def get_protocol(self) -> HttpProtocol:
        """Return a configured http protocol parser."""
        return HttpProtocol(self.config['bufsize'])

    @stats.time_context
    def frontend_send(self, client_socket, request):
        """Send the static frontend asset for the request path to the
        client."""
        response = {}
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.set_response(response)
        response['body'] = self.get_static_file(request['url']['path'])
        client_socket.send(prepare_response(response))

    @stats.time_context
    def status_send(self, client_socket, request):
        """Send the status dictionary as indented json to the client."""
        response = {}
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.set_response(response)
        response['body'] = self.status
        response['body'] = dump_json(response['body'], indent=2)
        client_socket.send(prepare_response(response))

    @stats.time_context
    def proxy_send(self, client_socket, server_socket, buffer: bytearray):
        """Send buffered request to server and receive the rest of the
        original client request"""
        protocol = self.get_protocol()
        protocol.set_request_parser()
        protocol.add_data_hook(server_socket.send)
        protocol.feed(buffer)
        protocol.recv(client_socket)

    @stats.time_context
    def proxy_recv(self, client_socket, server_socket):
        """Receive the proxied response and pipe to the client"""
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.add_data_hook(client_socket.send)
        protocol.recv(server_socket)

    @stats.time_context
    def client_recv(self, client_socket, request: dict, buffer: bytearray):
        """Receive client request and pipe to buffer in case of exceptions"""
        protocol = self.get_protocol()
        protocol.set_request_parser()
        protocol.set_request(request)
        protocol.add_data_hook(buffer.extend)
        protocol.add_url_hook(self.on_client_request_url)
        protocol.recv(client_socket)

    @staticmethod
    @stats.time_context
    def server_send(server_socket: socket.socket, request: dict):
        """Send magnified request to the server"""
        request['body'] = dump_json(request['body'])
        request['headers']['content-type'] = 'application/json; charset=UTF-8'
        server_socket.send(prepare_request(request))

    @stats.time_context
    def server_recv(self, server_socket: socket.socket, response: dict):
        """Receive the response to the magnified request from the server"""
        protocol = self.get_protocol()
        protocol.set_response_parser()
        protocol.set_response(response)
        protocol.recv(server_socket)

    @staticmethod
    @stats.time_context
    def client_send(request: dict, response: dict, client_socket):
        """Send the ranked results to the client"""
        # indent the json only when the client asked for ?pretty
        kwargs = dict(indent=2) if 'pretty' in request['url']['query'] else {}
        response['body'] = dump_json(response['body'], **kwargs)
        client_socket.send(prepare_response(response))

    @stats.time_context
    def error_send(self, client_socket, exc: Exception):
        """Send internal server error to the client."""
        response = {}
        protocol = self.get_protocol()
        protocol.set_response(response)
        response['body'] = dump_json({'error': repr(exc)}, indent=2)
        response['status'] = 500
        client_socket.send(prepare_response(response))

    @stats.time_context
    def model_rank(self, query: str, choices: List[str]) -> List[int]:
        """Rank the query and choices and return the argsorted indices"""
        return self.model.rank(query, choices)

    @stats.vars_context
    def record_topk_and_choices(self, topk: int = None, choices: list = None):
        """Add topk and choices to the running statistical averages"""

    @stats.vars_context
    def record_mrrs(self, upstream_mrr: float = None, model_mrr: float = None):
        """Add the upstream mrr, model mrr, and search boost to the stats"""
        # the boost is simply left out while the upstream average is zero
        with suppress(ZeroDivisionError):
            var = self.stats.record['vars']
            var['search_boost'] = {
                'avg': var['model_mrr']['avg'] / var['upstream_mrr']['avg']
            }

    @stats.time_context
    def server_connect(self, server_socket: socket.socket):
        """Connect proxied server socket

        :raises UpstreamConnectionError: if the upstream refuses the
            connection.
        """
        uaddress = (self.config['uhost'], self.config['uport'])
        try:
            server_socket.connect(uaddress)
        except ConnectionRefusedError as exc:
            raise UpstreamConnectionError(
                'Connect error for %s:%s' % uaddress) from exc

    @property
    def status(self) -> dict:
        """Return status dictionary in the case of a status request"""
        return {
            **self.config,
            **self.stats.record,
            'description': 'NBoost, for search ranking.'
        }

    def calculate_mrrs(self, true_cids: List[str], cids: List[str],
                       ranks: List[int]):
        """Calculate the mrr of the upstream server and reranked choices
        from the model. This only occurs if the client specified the
        "nboost" parameter in the request url or body."""
        upstream_mrr = self.calculate_mrr(true_cids, cids)
        reranked_cids = [cids[rank] for rank in ranks]
        model_mrr = self.calculate_mrr(true_cids, reranked_cids)
        self.record_mrrs(upstream_mrr=upstream_mrr, model_mrr=model_mrr)

    @staticmethod
    def calculate_mrr(correct: list, guesses: list):
        """Calculate mean reciprocal rank as the first correct result index

        Returns ``1 / position`` (1-based) of the first guess contained in
        ``correct``, or 0 when there is none.
        """
        for i, guess in enumerate(guesses, 1):
            if guess in correct:
                return 1 / i
        return 0

    def get_static_file(self, path: str) -> bytes:
        """Construct the static path of the frontend asset requested and
        return the raw file.

        ``/nboost`` maps to ``index.html``; any other path has its first
        ``/nboost/`` removed. ``404.html`` is served when the asset is not a
        regular file located below ``STATIC_PATH``.
        """
        if path == '/nboost':
            asset = 'index.html'
        else:
            asset = path.replace('/nboost/', '', 1)

        # for security reasons, make sure there is no access to other dirs.
        # ``parents`` is purely lexical, so both paths are resolved first;
        # otherwise ``/nboost/../../secret`` would still list STATIC_PATH
        # among its parents and escape the frontend directory.
        static_root = self.STATIC_PATH.resolve()
        static_path = static_root.joinpath(asset).resolve()

        # is_file() rather than exists(): a directory cannot be read as bytes
        if static_root in static_path.parents and static_path.is_file():
            return static_path.read_bytes()
        return self.STATIC_PATH.joinpath('404.html').read_bytes()

    @staticmethod
    def get_request_paths(request, configs) -> Tuple[str, int, list]:
        """Get the request jsonpaths noted in the configs

        :returns: the joined query string, the requested topk (or the
            configured ``default_topk``) and the true cids.
        :raises MissingQuery: if the joined query is empty.
        """
        queries = get_jsonpath(request, configs['query_path'])
        topks = get_jsonpath(request, configs['topk_path'])
        true_cids = get_jsonpath(request, configs['true_cids_path'])

        # coerce request variables from their paths
        topk = int(topks[0]) if topks else configs['default_topk']
        query = configs['delim'].join(queries)

        # check for errors
        if not query:
            raise MissingQuery

        return query, topk, true_cids

    @staticmethod
    def get_response_paths(response, configs) -> Tuple[list, list, list]:
        """Get the response jsonpaths noted in the configs

        :returns: the flattened choices, their cids and their cvalues.
        :raises InvalidChoices: if the choices are not a list or the three
            lists differ in length.
        """
        choices = get_jsonpath(response, configs['choices_path'])

        if not isinstance(choices, list):
            raise InvalidChoices('choices were not a list')

        choices = flatten(choices)
        cids = get_jsonpath(choices, '[*].' + configs['cids_path'])
        cvalues = get_jsonpath(choices, '[*].' + configs['cvalues_path'])

        # check for errors
        if not len(choices) == len(cids) == len(cvalues):
            raise InvalidChoices('number of choices, cids, and cvalues differ')

        return choices, cids, cvalues

    def loop(self, client_socket: socket.socket, address: Tuple[str, str]):
        """Main ioloop for reranking server results to the client. Exceptions
        raised in the http parser must be reraised from __context__ because
        they are caught by the MagicStack implementation"""
        buffer = bytearray()
        server_socket = self.set_socket()
        request = {}
        response = {}

        try:
            self.server_connect(server_socket)

            # NOTE(review): the with-block is assumed to span the whole
            # request/response exchange that goes through the http parser;
            # confirm its extent against the original layout.
            with HttpParserContext():
                # receive and buffer the client request
                self.client_recv(client_socket, request, buffer)
                self.logger.debug('Request (%s:%s): search.', *address)

                # combine runtime configs and preset configs
                configs = {**self.config, **request['body'].get('nboost', {})}
                query, topk, true_cids = self.get_request_paths(
                    request, configs)

                # magnify the size of the request to the server
                new_topk = topk * configs['multiplier']
                set_jsonpath(request, configs['topk_path'], new_topk)

                # send the magnified request to the upstream server
                self.server_send(server_socket, request)
                self.server_recv(server_socket, response)

            if response['status'] < 300:
                choices, cids, cvalues = self.get_response_paths(
                    response, configs)
                self.record_topk_and_choices(topk=topk, choices=choices)

                # use the model to rerank the choices
                ranks = self.model_rank(query, cvalues)[:topk]
                reranked = [choices[rank] for rank in ranks]
                set_jsonpath(response, configs['choices_path'], reranked)

                # if the "nboost" param was sent, calculate MRRs
                if true_cids is not None:
                    self.calculate_mrrs(true_cids, cids, ranks)

                response['body']['nboost'] = {}

                if self.qa and len(cvalues) > 0:
                    # NOTE(review): ``ranks`` holds argsorted indices, so
                    # ``ranks.index(min(ranks))`` is a position within
                    # ``ranks`` and not an index into ``cvalues``; the
                    # top-ranked passage would be ``cvalues[ranks[0]]``.
                    # Behaviour left unchanged, confirm the intent.
                    answer, offsets, score = self.qa_model.get_answer(
                        query, cvalues[ranks.index(min(ranks))])
                    response['body']['nboost']['qa_model'] = answer
                    response['body']['nboost']['qa_model_offsets'] = offsets
                    response['body']['nboost']['qa_model_score'] = score

            self.client_send(request, response, client_socket)

        except FrontendRequest:
            self.logger.info('Request (%s:%s): frontend request', *address)
            self.frontend_send(client_socket, request)

        except StatusRequest:
            self.logger.info('Request (%s:%s): status request', *address)
            self.status_send(client_socket, request)

        except UnknownRequest:
            # not a search request: pass it through to the upstream untouched
            self.logger.info('Request (%s:%s): unknown path "%s"',
                             *address, request['url']['path'])
            self.proxy_send(client_socket, server_socket, buffer)
            self.proxy_recv(client_socket, server_socket)

        except MissingQuery:
            self.logger.warning('Request (%s:%s): missing query', *address)
            self.proxy_send(client_socket, server_socket, buffer)
            self.proxy_recv(client_socket, server_socket)

        except InvalidChoices as exc:
            self.logger.warning('Request (%s:%s): %s', *address, *exc.args)
            self.proxy_send(client_socket, server_socket, buffer)
            self.proxy_recv(client_socket, server_socket)

        except Exception as exc:
            # for misc errors, send back json error msg
            self.logger.error('Request (%s:%s): %s.', *address, exc,
                              exc_info=True)
            self.error_send(client_socket, exc)

        finally:
            client_socket.close()
            server_socket.close()

    def run(self):
        """Same as socket server run() but logs"""
        self.logger.info(dump_json(self.config, indent=4).decode())
        super().run()

    def close(self):
        """Close the proxy server and model"""
        self.logger.info('Closing model...')
        self.model.close()
        super().close()
def set_parser() -> ArgumentParser:
    """Build the argument parser holding the default nboost cli arguments.

    Boolean options (``--verbose``, ``--qa``, ``--filter_results``) take an
    explicit value. They are parsed with a string-aware converter rather
    than ``type=bool``: ``bool('False')`` is ``True``, so with ``type=bool``
    it was impossible to switch one of these options off from the command
    line once the flag was given.

    Returns:
        A configured ``ArgumentParser`` with ``prog='nboost'``.
    """

    def parse_bool(value: str) -> bool:
        """Convert a command line string to a bool.

        Empty strings and the usual "off" spellings (case-insensitive) map
        to False. Every other value maps to True, so any input that used to
        enable an option under ``type=bool`` still enables it.
        """
        falsy = {'', '0', 'false', 'f', 'no', 'n', 'off'}
        return value.strip().lower() not in falsy

    parser = ArgumentParser(
        prog='nboost',
        description=DESCRIPTION,
        # widen the help layout so long option names and defaults fit
        formatter_class=lambda prog: AdHf(
            prog, max_help_position=100, width=100))

    # The *_path options default to SUPPRESS, so they are absent from the
    # parsed namespace unless the user passes them explicitly.
    path_options = (
        ('--capture_path', CAPTURE_PATH),
        ('--query_path', QUERY_PATH),
        ('--topk_path', TOPK_PATH),
        ('--cvalues_path', CVALUES_PATH),
        ('--cids_path', CIDS_PATH),
        ('--true_cids_path', TRUE_CIDS_PATH),
        ('--choices_path', CHOICES_PATH),
    )
    for flag, help_text in path_options:
        parser.add_argument(flag, type=str, default=SUPPRESS, help=help_text)

    parser.add_argument('--verbose', type=parse_bool, default=False,
                        help=VERBOSE)
    parser.add_argument('--host', type=str, default='0.0.0.0', help=HOST)
    parser.add_argument('--port', type=int, default=8000, help=PORT)
    parser.add_argument('--uhost', type=str, default='0.0.0.0', help=UHOST)
    parser.add_argument('--uport', type=int, default=9200, help=UPORT)
    # NOTE(review): this literal was wrapped across a line break in the
    # reviewed copy; period + space is assumed -- confirm.
    parser.add_argument('--delim', type=str, default='. ', help=DELIM)
    parser.add_argument('--lr', type=float, default=10e-3, help=LR)
    parser.add_argument('--max_seq_len', type=int, default=64,
                        help=MAX_SEQ_LEN)
    parser.add_argument('--bufsize', type=int, default=2048, help=BUFSIZE)
    parser.add_argument('--batch_size', type=int, default=4, help=BATCH_SIZE)
    parser.add_argument('--multiplier', type=int, default=5, help=MULTIPLIER)
    parser.add_argument('--workers', type=int, default=10, help=WORKERS)
    parser.add_argument('--data_dir', type=Path,
                        default=PKG_PATH.joinpath('.cache'), help=DATA_DIR)
    parser.add_argument('--config', type=str, default='elasticsearch',
                        choices=CONFIG_MAP.keys(), help=CONFIG)
    parser.add_argument('--model', type=str, default='', help=MODEL)
    parser.add_argument('--model_dir', type=str,
                        default='pt-tinybert-msmarco', help=MODEL_DIR)
    parser.add_argument('--qa', type=parse_bool, default=False, help=QA)
    parser.add_argument('--qa_model', type=str, default='', help=QA_MODEL)
    parser.add_argument('--qa_model_dir', type=str,
                        default='distilbert-base-uncased-distilled-squad',
                        help=QA_MODEL_DIR)
    parser.add_argument('--filter_results', type=parse_bool, default=False,
                        help=FILTER_RESULTS)
    return parser