Beispiel #1
0
def encode_selected_predictions(predictions, selected_content_keys, accept):
    """Serialize the selected predictions into the requested accept mime-type.

    :param predictions: selected predictions as a list of dicts, i.e. the
                        output of serve_utils.get_selected_predictions(...),
                        for example:
                        [{"predicted_label": 1, "probabilities": [0.4, 0.6]},
                         {"predicted_label": 0, "probabilities": [0.9, 0.1]}]
    :param selected_content_keys: selected content keys (list of str); only
                                  consulted when ``accept`` is "text/csv"
    :param accept: accept mime-type (str); compared for exact equality with
                   "application/json", "application/jsonlines",
                   "application/x-recordio-protobuf" and "text/csv"
    :return: the predictions encoded for ``accept``
    :raises RuntimeError: if ``accept`` matches none of the supported types
    """
    def _as_csv():
        csv_body = _encode_selected_predictions_csv(predictions, selected_content_keys)
        # When SAGEMAKER_BATCH is truthy, a trailing newline is appended.
        return csv_body + '\n' if SAGEMAKER_BATCH else csv_body

    # Ordered (mime-type, encoder) pairs. Each encoder is a thunk, so the
    # helpers for the other mime-types are never looked up or called.
    handlers = (
        ("application/json",
         lambda: json.dumps({"predictions": predictions})),
        ("application/jsonlines",
         lambda: json_to_jsonlines({"predictions": predictions})),
        ("application/x-recordio-protobuf",
         lambda: _encode_selected_predictions_recordio_protobuf(predictions)),
        ("text/csv", _as_csv),
    )
    for mime_type, encode in handlers:
        if accept == mime_type:
            return encode()

    raise RuntimeError(
        "Cannot encode selected predictions into accept type '{}'.".format(
            accept))
Beispiel #2
0
def test_encoder_jsonlines_from_json():
    """A {'predictions': [...]} JSON string becomes one newline-terminated line per entry."""
    predictions = [
        {"predicted_label": 1, "probabilities": [0.4, 0.6]},
        {"predicted_label": 0, "probabilities": [0.9, 0.1]},
    ]
    json_payload = json.dumps({'predictions': predictions})
    expected = (
        b'{"predicted_label": 1, "probabilities": [0.4, 0.6]}\n'
        b'{"predicted_label": 0, "probabilities": [0.9, 0.1]}\n'
    )

    actual = encoder.json_to_jsonlines(json_payload)
    assert expected == actual
Beispiel #3
0
def test_encoder_jsonlines_from_json_error():
    """An extra 'metadata' key next to 'predictions' should make the conversion raise ValueError."""
    payload_with_extra_key = json.dumps({'predictions': [], 'metadata': []})
    with pytest.raises(ValueError):
        encoder.json_to_jsonlines(payload_with_extra_key)