示例#1
0
def on_server_request(session: Session):
    """Send the session's request to the upstream server and store the reply.

    A deep copy of ``session.request`` is stripped of the ``host`` header, the
    ``nboost`` body entry and every query parameter named after an attribute
    of ``defaults`` before it is sent to ``session.uhost:session.uport``. The
    upstream status, headers, JSON body and reason are written into
    ``session.response``.

    Raises:
        UpstreamServerError: if the connection to the upstream server fails,
            or if the upstream server answers with a status code >= 400.
    """
    # work on a copy so that session.request itself is left untouched
    request = deepcopy(session.request)
    request['headers'].pop('host', '')
    request['body'].pop('nboost', '')

    for default in defaults.__dict__:
        request['url']['query'].pop(default, '')

    try:
        response = requests.request(
            method=request['method'],
            url='{protocol}://{host}:{port}{path}'.format(
                protocol='https' if session.ussl else 'http',
                host=session.uhost,
                port=session.uport,
                path=unparse_url(request['url'])
            ),
            headers=request['headers'],
            json=request['body']
        )

    # NOTE(review): this only catches connection failures raised by requests
    # if ``ConnectionError`` is imported from ``requests.exceptions`` (which
    # is not a subclass of the builtin) -- confirm against the file's imports.
    except ConnectionError as exc:
        # chain the cause so the original traceback is kept
        raise UpstreamServerError(exc) from exc

    session.response['status'] = response.status_code
    session.response['headers'] = {k.lower(): v for k, v in response.headers.items()}

    try:
        body = response.json()
    except ValueError:
        # Upstream error pages are often not JSON. Report them as an upstream
        # failure instead of masking them with a JSON decoding error.
        if response.status_code >= 400:
            raise UpstreamServerError(response.text) from None
        raise

    session.response['body'].update(body)
    session.response['version'] = 'HTTP/1.1'
    session.response['reason'] = response.reason

    if response.status_code >= 400:
        raise UpstreamServerError(response.text)

    session.stats['choices'] = len(session.choices)
示例#2
0
    def test_request_2(self):
        """The session query is taken from the ``q`` URL query parameter."""
        url_query = {"q": "message:test query", "size": 20}

        session = Session()
        session.request['url'] = {"query": url_query}

        self.assertEqual("message:test query", session.query)
示例#3
0
 def test_request_3(self):
     """The session query is taken from a ``match`` clause in the body."""
     request_body = {"query": {"match": "hello there"}}

     session = Session()
     session.request['body'] = request_body

     self.assertEqual("hello there", session.query)
示例#4
0
def on_status_request(client_socket, session: Session, status: dict):
    """Send ``status`` to the client as an indented JSON response.

    Args:
        client_socket: connected socket the response is written to.
        session: session whose ``response`` is filled in and sent.
        status: status information to serialize as the response body.
    """
    protocol = HttpProtocol()
    protocol.set_response(session.response)
    session.response['body'] = dump_json(status, indent=2)
    prepared_response = prepare_response(session.response)
    # sendall() keeps writing until every byte is sent; send() may stop early
    client_socket.sendall(prepared_response)
示例#5
0
def on_proxy_error(client_socket, exc: Exception):
    """Send an internal server error (status 500) describing ``exc``.

    The body is a JSON object whose ``error`` entry is ``repr(exc)``.

    Args:
        client_socket: connected socket the response is written to.
        exc: the exception to report to the client.
    """
    # a fresh session is used, so nothing from the failed request is reused
    session = Session()
    protocol = HttpProtocol()
    protocol.set_response(session.response)
    session.response['body'] = dump_json({'error': repr(exc)}, indent=2)
    session.response['status'] = 500
    prepared_response = prepare_response(session.response)
    # sendall() keeps writing until every byte is sent; send() may stop early
    client_socket.sendall(prepared_response)
示例#6
0
    def test_response_1(self):
        """Choices, cvalues and cids are read from a three-hit response body."""
        documents = [
            ("0", 1.4, "trying out Elasticsearch"),
            ("1", 1.34245, "second result"),
            ("2", 1.121234, "third result"),
        ]
        # every hit shares the same index/type and differs in id/score/message
        hits = [
            {
                "_index": "twitter",
                "_type": "_doc",
                "_id": doc_id,
                "_score": score,
                "_source": {"message": message},
            }
            for doc_id, score, message in documents
        ]

        session = Session()
        session.response['body'] = {
            "nboost": {'cvalues_path': '_source.message'},
            "took": 5,
            "timed_out": False,
            "_shards": {"total": 1, "successful": 1, "skipped": 0, "failed": 0},
            "hits": {
                "total": {"value": 1, "relation": "eq"},
                "max_score": 1.3862944,
                "hits": hits,
            },
        }

        self.assertEqual(1.4, session.choices[0]['_score'])

        expected_cvalues = ["trying out Elasticsearch", "second result", "third result"]
        self.assertEqual(expected_cvalues, session.cvalues)
        self.assertEqual(['0', '1', '2'], session.cids)
示例#7
0
    def test_request_5(self):
        """A ``query_path`` URL parameter points the session at the query."""
        template_params = {"query": "my query", "from": 0, "size": 9}

        session = Session()
        session.request['url']['query']['query_path'] = 'body.params.query'
        session.request['body'] = {
            "id": "searchTemplate",
            "params": template_params,
        }

        self.assertEqual('my query', session.query)
示例#8
0
    def test_request_1(self):
        """The query comes from the ``term`` clause and topk from ``size``."""
        session = Session()
        session.request['body'] = {
            "from": 0, "size": 20,
            "query": {
                # the fixture must hold the value asserted below
                "term": {"user": "kimchy"}
            },
            "nboost": {
                "cids": ['0', '2']
            }
        }

        self.assertEqual("kimchy", session.query)
        self.assertEqual(20, session.topk)
示例#9
0
def on_debug(session: Session):
    """Expose the session's configs and query details in the nboost response.

    Each name in ``session.cli_configs`` is added together with the session
    attribute of the same name, followed by the query, the ``topk`` stat and
    the cvalues.
    """
    record = session.add_nboost_response

    for name in session.cli_configs:
        value = getattr(session, name)
        record(name, value)

    record('query', session.query)
    record('topk', session.stats['topk'])
    record('cvalues', session.cvalues)
示例#10
0
def on_rerank_response(session: Session, model: RerankModel):
    """Rerank the response choices with ``model`` and record statistics.

    The time spent ranking is stored in ``session.stats['rerank_time']``.
    When ``session.rerank_cids`` is set, the MRR is stored before reranking
    as ``server_mrr`` and after reranking as ``model_mrr``. Nothing is
    returned.
    """
    if session.rerank_cids:
        server_mrr = calculate_mrr(session.rerank_cids, session.cids)
        session.stats['server_mrr'] = server_mrr

    # HACK (flagged by the original author): topk is read back from the stats
    limit = session.stats['topk']

    began = time.perf_counter()
    ranking = model.rank(
        session.query,
        session.cvalues,
        filter_results=session.filter_results,
    )[:limit]
    session.stats['rerank_time'] = time.perf_counter() - began

    # write the choices back in the order chosen by the model
    reordered = [session.choices[position] for position in ranking]
    session.set_response_path(session.choices_path, reordered)

    if session.rerank_cids:
        model_mrr = calculate_mrr(session.rerank_cids, session.cids)
        session.stats['model_mrr'] = model_mrr
示例#11
0
def on_frontend_request(client_socket: socket.socket, session: Session):
    """Send a static frontend asset to the client.

    ``/nboost`` maps to ``index.html``; any other path has its first
    ``/nboost/`` removed and is looked up below ``resources/frontend``.
    Anything that is not a regular file inside that directory is answered
    with ``404.html``.

    Args:
        client_socket: connected socket the response is written to.
        session: session holding the parsed request and the response to fill.
    """
    protocol = HttpProtocol()
    protocol.set_response(session.response)

    # resolve both paths: comparing unresolved paths lets ``..`` segments
    # escape the directory while still listing it among the lexical parents
    frontend_path = Path(__file__).parent.joinpath('resources/frontend').resolve()
    url_path = session.request['url']['path']

    if url_path == '/nboost':
        asset = 'index.html'
    else:
        asset = url_path.replace('/nboost/', '', 1)

    asset_path = frontend_path.joinpath(asset).resolve()

    # for security reasons, make sure there is no access to other dirs;
    # is_file() also rejects directories, which read_bytes() cannot read
    if frontend_path in asset_path.parents and asset_path.is_file():
        session.response['body'] = asset_path.read_bytes()
    else:
        session.response['body'] = frontend_path.joinpath('404.html').read_bytes()

    prepared_response = prepare_response(session.response)
    # assets can exceed the socket buffer; sendall() writes every byte
    client_socket.sendall(prepared_response)
示例#12
0
    def test_request_4(self):
        """A wildcard ``query_path`` joins every matched query with '. '."""

        def match_clause(query, operator):
            # one ``should`` entry matching on the ``text`` field
            return {"match": {"text": {"query": query, "operator": operator}}}

        script_source = "1 + ((5 - doc[\"priority\"].value) / 10.0) + ((doc[\"branch\"].value == \"All\") ? 0.5 : 0)"
        should_clauses = [
            match_clause("query one", "and"),
            match_clause("query two", "or"),
        ]

        session = Session(query_path='body.query.function_score.query.bool.should.[*].match.text.query')
        session.request['body'] = {
            "size": 11,
            "query": {
                "function_score": {
                    "query": {"bool": {"should": should_clauses}},
                    "script_score": {"script": {"source": script_source}},
                }
            },
        }

        self.assertEqual('query one. query two', session.query)
示例#13
0
    def test_response(self):
        """Feed a response in two parts and check what is parsed after each."""
        session = Session()
        protocol = HttpProtocol()
        protocol.set_response_parser()
        protocol.set_response(session.response)
        response = session.response

        # the first part carries the status line
        protocol.feed(RESPONSE_PART_1)
        self.assertFalse(protocol._is_done)
        self.assertEqual(response['status'], 201)
        self.assertEqual(response['reason'], 'Created')

        # the second part completes the message
        protocol.feed(RESPONSE_PART_2)
        self.assertTrue(protocol._is_done)
        self.assertEqual(response['headers']['test-header'], '2')

        expected_body = {'nboost': {}, 'test': 'response'}
        self.assertEqual(expected_body, response['body'])
示例#14
0
    def test_request(self):
        """Feed a request in two parts and check what is parsed after each."""

        def check_url(url):
            self.assertEqual(url['path'], '/search')

        def check_data(data):
            self.assertIsInstance(data, bytes)

        protocol = HttpProtocol()
        protocol.set_request_parser()
        session = Session()
        protocol.set_request(session.request)
        protocol.add_url_hook(check_url)
        protocol.add_data_hook(check_data)
        request = session.request

        # the first part carries the request line and URL
        protocol.feed(REQUEST_PART_1)
        self.assertFalse(protocol._is_done)
        self.assertEqual(request['method'], 'GET')
        self.assertEqual(request['url']['query']['para'], 'message')

        # the second part completes the message
        protocol.feed(REQUEST_PART_2)
        self.assertTrue(protocol._is_done)
        self.assertEqual(request['headers']['test-header'], 'Testing')

        expected_body = {'test': 'request'}
        self.assertEqual(expected_body, request['body'])
示例#15
0
def on_qa(session: Session, qa_model: QAModel):
    """Ask the QA model for an answer in the first cvalue and record stats.

    Nothing happens when the session has no cvalues. Otherwise the time spent
    in the model is stored in ``session.stats['qa_time']``, and the answer
    text and positions are added to the nboost response when the score is
    above ``session.qa_threshold``. If the first choice id has a labelled
    span in ``session.qa_cids``, its overlap with the predicted span is
    stored in ``session.stats['qa_overlap']``. Nothing is returned.
    """
    if not session.cvalues:
        return

    start_time = time.perf_counter()
    answer, start_pos, stop_pos, score = qa_model.get_answer(session.query, session.cvalues[0])
    session.stats['qa_time'] = time.perf_counter() - start_time

    if score > session.qa_threshold:
        session.add_nboost_response('answer_text', answer)
        session.add_nboost_response('answer_start_pos', start_pos)
        session.add_nboost_response('answer_stop_pos', stop_pos)

    if session.cids:
        first_choice_id = session.cids[0]
        if first_choice_id in session.qa_cids:
            qa_start_pos, qa_end_pos = session.qa_cids[first_choice_id]
            # Compare the labelled span with the model's predicted span;
            # passing the labelled span twice only compared it with itself.
            # NOTE(review): assumes calculate_overlap takes two (start, end)
            # pairs in this order -- confirm against its definition.
            session.stats['qa_overlap'] = calculate_overlap(qa_start_pos,
                                                            qa_end_pos,
                                                            start_pos,
                                                            stop_pos)
示例#16
0
 def loop(self, client_socket, address):
     """Answer the connection with the JSON-encoded ``RESPONSE`` and close it.

     Args:
         client_socket: connected socket the response is written to.
         address: address of the client (unused).
     """
     try:
         session = Session()
         session.response['body'] = dump_json(RESPONSE)
         prepared_response = prepare_response(session.response)
         # sendall() keeps writing until every byte is sent
         client_socket.sendall(prepared_response)
     finally:
         # close the socket even if building or sending the response fails
         client_socket.close()
示例#17
0
def on_rerank_request(session: Session):
    """Store the session's topk in the stats, then request topn results.

    The topk is read before the request is rewritten, and the value at
    ``session.topk_path`` in the request is then replaced by ``session.topn``.
    """
    requested_topk = session.topk
    session.stats['topk'] = requested_topk

    session.set_request_path(session.topk_path, session.topn)
示例#18
0
def on_client_response(session: Session, client_socket):
    """Serialize the response body to JSON and send the response to the client.

    The body is indented when the request URL has a ``pretty`` query
    parameter.

    Args:
        session: session holding the request and the response to send.
        client_socket: connected socket the response is written to.
    """
    kwargs = {'indent': 2} if 'pretty' in session.request['url']['query'] else {}
    session.response['body'] = dump_json(session.response['body'], **kwargs)
    prepared_response = prepare_response(session.response)
    # result lists can exceed the socket buffer; sendall() writes every byte
    client_socket.sendall(prepared_response)
示例#19
0
 def get_session(self):
     """Build a new ``Session`` from the keyword arguments in ``self.kwargs``."""
     session_kwargs = self.kwargs
     return Session(**session_kwargs)