Exemplo n.º 1
0
    def test_standard_model(self):
        """Serve the BiDAF predictor and exercise the HTML page, /predict and /predict_batch."""
        expected_keys = ("best_span_str", "span_start_logits")

        app = make_app(predictor=self.bidaf_predictor,
                       field_names=["passage", "question"])
        app.testing = True
        client = app.test_client()

        # The rendered demo page should mention each configured input field.
        html = client.get("/").get_data()
        for field_name in (b"passage", b"question"):
            assert field_name in html

        # A single prediction includes both the best span and the start logits.
        single = json.loads(post_json(client, "/predict", PAYLOAD).get_data())
        for key in expected_keys:
            assert key in single

        # A batch request yields one result per submitted payload.
        num_requests = 8
        batch_response = post_json(client, "/predict_batch", [PAYLOAD] * num_requests)
        results = json.loads(batch_response.get_data())
        assert len(results) == num_requests
        for result in results:
            for key in expected_keys:
                assert key in result
Exemplo n.º 2
0
    def test_sanitizer(self):
        """A sanitizer given to make_app must filter both /predict and /predict_batch output."""
        def sanitize(result: JsonDict) -> JsonDict:
            # Keep only the keys that start with "best_span".
            return {name: item for name, item in result.items()
                    if name.startswith("best_span")}

        def check_sanitized(output):
            # best_span_str survives the filter; the start logits do not.
            assert "best_span_str" in output
            assert "span_start_logits" not in output

        app = make_app(predictor=self.bidaf_predictor,
                       field_names=["passage", "question"],
                       sanitizer=sanitize)
        app.testing = True
        client = app.test_client()

        check_sanitized(json.loads(post_json(client, "/predict", PAYLOAD).get_data()))

        count = 8
        batch = json.loads(post_json(client, "/predict_batch", [PAYLOAD] * count).get_data())
        assert len(batch) == count
        for output in batch:
            check_sanitized(output)
Exemplo n.º 3
0
    def test_sanitizer(self):
        """A sanitizer passed to make_app should strip keys from /predict responses."""
        def sanitize(result: JsonDict) -> JsonDict:
            # Retain only keys beginning with "best_span".
            return {name: result[name] for name in result
                    if name.startswith("best_span")}

        app = make_app(
            predictor=self.bidaf_predictor,
            field_names=['passage', 'question'],
            sanitizer=sanitize,
        )
        app.testing = True
        client = app.test_client()

        payload_json = json.loads(post_json(client, '/predict', PAYLOAD).get_data())
        assert 'best_span_str' in payload_json
        assert 'span_start_logits' not in payload_json
Exemplo n.º 4
0
    def test_sanitizer(self):
        """Responses from /predict must be passed through the configured sanitizer."""
        def sanitize(result: JsonDict) -> JsonDict:
            # Drop every key except those prefixed with "best_span".
            return {k: v for k, v in result.items() if k.startswith("best_span")}

        app = make_app(predictor=self.bidaf_predictor,
                       field_names=['passage', 'question'],
                       sanitizer=sanitize)
        app.testing = True
        client = app.test_client()

        predict_response = post_json(client, '/predict', PAYLOAD)
        body = json.loads(predict_response.get_data())
        # The sanitizer keeps best_span_str but removes the start logits.
        assert 'best_span_str' in body
        assert 'span_start_logits' not in body
Exemplo n.º 5
0
    def test_sanitizer(self):
        """Verify that the sanitizer passed to ``make_app`` filters /predict output.

        The sanitizer keeps only keys starting with ``best_span``, so
        ``best_span_str`` must survive while ``span_start_logits`` is dropped.
        """
        def sanitize(result):
            # Keep only the best-span entries of the prediction; iterating
            # ``items()`` directly avoids materializing a throwaway list.
            return {key: value for key, value in result.items()
                    if key.startswith(u"best_span")}

        app = make_app(predictor=self.bidaf_predictor,
                       field_names=[u'passage', u'question'],
                       sanitizer=sanitize)
        app.testing = True
        client = app.test_client()

        response = post_json(client, u'/predict', PAYLOAD)
        data = json.loads(response.get_data())
        assert u'best_span_str' in data
        assert u'span_start_logits' not in data
Exemplo n.º 6
0
    def test_standard_model(self):
        """Smoke-test the demo app: the index page renders and /predict returns BiDAF output."""
        app = make_app(
            predictor=self.bidaf_predictor,
            field_names=['passage', 'question'],
        )
        app.testing = True
        client = app.test_client()

        # The landing page should reference both configured input fields.
        index_html = client.get('/').get_data()
        assert b"passage" in index_html
        assert b"question" in index_html

        # The prediction endpoint should return the span string and the start logits.
        prediction = json.loads(post_json(client, '/predict', PAYLOAD).get_data())
        for expected in ('best_span_str', 'span_start_logits'):
            assert expected in prediction
Exemplo n.º 7
0
    def test_standard_model(self):
        """Check both the HTML front end and the /predict JSON backend of ``make_app``."""
        app = make_app(predictor=self.bidaf_predictor, field_names=['passage', 'question'])
        app.testing = True
        client = app.test_client()

        # Front end: both input field names must appear in the served page.
        page_bytes = client.get('/').get_data()
        assert all(token in page_bytes for token in (b"passage", b"question"))

        # Backend: the JSON prediction carries the best span and its start logits.
        backend_response = post_json(client, '/predict', PAYLOAD)
        output = json.loads(backend_response.get_data())
        assert all(key in output for key in ('best_span_str', 'span_start_logits'))
Exemplo n.º 8
0
def main():
    """Command-line entry point: load a predictor and serve the demo web app.

    Options, in addition to those registered by ``env_utils.add_inference_argument``:
      -title     title passed to ``make_app`` (default ``SpanNet``)
      -port      TCP port the server listens on (default 5000)
      -log-file  optional path handed to ``env_utils.pre_logger``
    """
    cli = argparse.ArgumentParser()
    env_utils.add_inference_argument(cli)
    cli.add_argument('-title', dest='title', type=str, default='SpanNet')
    cli.add_argument('-port', dest='port', type=int, default=5000)
    cli.add_argument('-log-file', dest='log_file', type=str)
    args = cli.parse_args()

    # Logging is configured before the (possibly slow) model load.
    env_utils.pre_logger(args.log_file)

    # NOTE(review): `args.model` / `args.device` are presumably registered by
    # add_inference_argument — confirm.
    span_predictor = load_predictor(model_path=args.model, device=args.device)

    demo_app = make_app(predictor=span_predictor, field_names=['text'], title=args.title)
    # Add CORS support to the app (presumably flask_cors) so browsers on other
    # origins can call it.
    CORS(demo_app)

    # Listen on all interfaces at the requested port.
    server = WSGIServer(('0.0.0.0', args.port), demo_app)
    print(f"Model loaded, serving demo on port {args.port}")
    server.serve_forever()
Exemplo n.º 9
0
    def test_static_dir(self):
        """Serve files from ``static_dir`` and check that they come back unchanged.

        ``/`` should return ``index.html``, and ``/jpg.txt`` should return that
        file's contents.
        """
        html = """<html><body>THIS IS A STATIC SITE</body></html>"""
        jpg = """something about a jpg"""

        # Write with an explicit encoding so the fixtures match the UTF-8
        # decode below regardless of the platform's default locale encoding.
        with open(os.path.join(self.TEST_DIR, 'index.html'), 'w', encoding='utf-8') as f:
            f.write(html)

        with open(os.path.join(self.TEST_DIR, 'jpg.txt'), 'w', encoding='utf-8') as f:
            f.write(jpg)

        app = make_app(predictor=self.bidaf_predictor, static_dir=self.TEST_DIR)
        app.testing = True
        client = app.test_client()

        response = client.get('/')
        assert response.status_code == 200
        data = response.get_data().decode('utf-8')
        assert data == html

        # Use an absolute URL path rather than relying on the router to add "/".
        response = client.get('/jpg.txt')
        assert response.status_code == 200
        data = response.get_data().decode('utf-8')
        assert data == jpg
Exemplo n.º 10
0
    def test_static_dir(self):
        """Serve files from ``static_dir`` and check that they come back unchanged.

        ``/`` should return ``index.html``, and ``/jpg.txt`` should return that
        file's contents.
        """
        html = """<html><body>THIS IS A STATIC SITE</body></html>"""
        jpg = """something about a jpg"""

        # Write with an explicit encoding so the fixtures match the UTF-8
        # decode below regardless of the platform's default locale encoding.
        with open(os.path.join(self.TEST_DIR, 'index.html'), 'w', encoding='utf-8') as f:
            f.write(html)

        with open(os.path.join(self.TEST_DIR, 'jpg.txt'), 'w', encoding='utf-8') as f:
            f.write(jpg)

        app = make_app(predictor=self.bidaf_predictor, static_dir=self.TEST_DIR)
        app.testing = True
        client = app.test_client()

        response = client.get('/')
        assert response.status_code == 200
        data = response.get_data().decode('utf-8')
        assert data == html

        # Use an absolute URL path rather than relying on the router to add "/".
        response = client.get('/jpg.txt')
        assert response.status_code == 200
        data = response.get_data().decode('utf-8')
        assert data == jpg