예제 #1
0
def _handle_selectable_inference_response(predictions, accept):
    """Retrieve the additional prediction data for selectable inference mode.

    Reads the objective name and ``num_class`` from the model config returned by
    ``ScoringService.get_config_json()``, selects the requested output keys from
    ``predictions`` and encodes them in the ``accept`` format.

    :param predictions: output of xgboost predict (list of numpy objects)
    :param accept: requested accept type (str)
    :return: flask response with encoded predictions (HTTP 200, ``accept``
        mimetype), or an HTTP 500 ``text/plain`` response whose body is the
        error message if any step fails
    """
    try:
        config = ScoringService.get_config_json()
        learner = config['learner']
        objective = learner['objective']['name']
        # '' is forwarded when the config has no num_class entry.
        num_class = learner['learner_model_param'].get('num_class', '')
        selected_content_keys = serve_utils.get_selected_output_keys()

        selected_content = serve_utils.get_selected_predictions(
            predictions, selected_content_keys, objective, num_class=num_class)

        response = serve_utils.encode_selected_predictions(
            selected_content, selected_content_keys, accept)
    except Exception as e:
        logging.exception(e)
        # Label the error body explicitly as plain text. Without a mimetype,
        # Flask defaults to text/html. The message may echo client-supplied
        # input such as the accept type (NOTE(review): confirm what serve_utils
        # puts into its exception messages).
        return flask.Response(response=str(e),
                              status=http.client.INTERNAL_SERVER_ERROR,
                              mimetype='text/plain')

    return flask.Response(response=response,
                          status=http.client.OK,
                          mimetype=accept)
def test_encode_selected_predictions_protobuf():
    """Check the recordio-protobuf encoding of the selected predictions.

    The stream must contain exactly one record per expected prediction. Each
    decoded record must carry ``predicted_label`` and ``probabilities``
    float32 tensors matching the expected values.
    """
    expected_predicted_labels = [[1], [0]]
    expected_probabilities = [[0.4, 0.6], [0.9, 0.1]]

    protobuf_response = serve_utils.encode_selected_predictions(TEST_PREDICTIONS, TEST_KEYS,
                                                                "application/x-recordio-protobuf")
    stream = io.BytesIO(protobuf_response)
    records = list(_read_recordio(stream))
    # zip() stops at its shortest input. Without this check, a stream with
    # missing (or no) records would let the loop below pass vacuously.
    assert len(records) == len(expected_predicted_labels)

    for recordio, predicted_label, probabilities in zip(records, expected_predicted_labels,
                                                        expected_probabilities):
        record = Record()
        record.ParseFromString(recordio)
        assert record.label["predicted_label"].float32_tensor.values == predicted_label
        # assert_allclose also rejects a shape mismatch, which
        # all(np.isclose(...)) could hide through broadcasting. The tolerances
        # are np.isclose's defaults, so float32 rounding is still accepted.
        np.testing.assert_allclose(record.label["probabilities"].float32_tensor.values,
                                   probabilities, rtol=1e-05, atol=1e-08)
def test_encode_selected_content_error():
    """Encoding into an accept type the encoder does not support raises RuntimeError."""
    unsupported_accept = "text/libsvm"
    with pytest.raises(RuntimeError):
        serve_utils.encode_selected_predictions(
            TEST_PREDICTIONS,
            TEST_KEYS,
            unsupported_accept,
        )
def test_encode_selected_predictions_csv():
    """CSV encoding yields a str of newline-separated rows with quoted list fields."""
    expected_csv = '1,"[0.4, 0.6]"\n0,"[0.9, 0.1]"'
    csv_output = serve_utils.encode_selected_predictions(
        TEST_PREDICTIONS, TEST_KEYS, "text/csv"
    )
    assert csv_output == expected_csv
def test_encode_selected_predictions_jsonlines():
    """JSON Lines encoding yields bytes with one newline-terminated object per prediction."""
    expected_jsonlines = (
        b'{"predicted_label": 1, "probabilities": [0.4, 0.6]}\n'
        b'{"predicted_label": 0, "probabilities": [0.9, 0.1]}\n'
    )
    jsonlines_output = serve_utils.encode_selected_predictions(
        TEST_PREDICTIONS, TEST_KEYS, "application/jsonlines"
    )
    assert jsonlines_output == expected_jsonlines
def test_encode_selected_predictions_json():
    """JSON encoding equals json.dumps of the predictions under a "predictions" key."""
    payload = {"predictions": TEST_PREDICTIONS}
    expected_json = json.dumps(payload)
    json_output = serve_utils.encode_selected_predictions(
        TEST_PREDICTIONS, TEST_KEYS, "application/json"
    )
    assert json_output == expected_json