Пример #1
0
def test_classification_request_pb(sagemaker_session):
    request = classification_pb2.ClassificationRequest()
    request.model_spec.name = "generic_model"
    request.model_spec.signature_name = DEFAULT_SERVING_SIGNATURE_DEF_KEY
    example = request.input.example_list.examples.add()
    example.features.feature[PREDICT_INPUTS].float_list.value.extend([6.4, 3.2, 4.5, 1.5])

    predictor = RealTimePredictor(sagemaker_session=sagemaker_session,
                                  endpoint=ENDPOINT,
                                  deserializer=tf_deserializer,
                                  serializer=tf_serializer)

    expected_response = classification_pb2.ClassificationResponse()
    classes = expected_response.result.classifications.add().classes

    class_0 = classes.add()
    class_0.label = "0"
    class_0.score = 0.00128903763834

    class_1 = classes.add()
    class_1.label = "1"
    class_1.score = 0.981432199478

    class_2 = classes.add()
    class_2.label = "2"
    class_2.score = 0.0172787327319

    mock_response(expected_response.SerializeToString(), sagemaker_session, PROTO_CONTENT_TYPE)

    result = predictor.predict(request)

    sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(
        Accept=PROTO_CONTENT_TYPE,
        Body=request.SerializeToString(),
        ContentType=PROTO_CONTENT_TYPE,
        EndpointName='myendpoint'
    )

    # python 2 and 3 protobuf serialization has different precision so I'm checking
    # the version here
    if sys.version_info < (3, 0):
        assert str(result) == """result {
  classifications {
    classes {
      label: "0"
      score: 0.00128903763834
    }
    classes {
      label: "1"
      score: 0.981432199478
    }
    classes {
      label: "2"
      score: 0.0172787327319
    }
  }
}
"""
    else:
        assert str(result) == """result {
Пример #2
0
def test_classification_request_csv(sagemaker_session):
    data = [1, 2, 3]
    predictor = RealTimePredictor(serializer=tf_csv_serializer,
                                  deserializer=tf_deserializer,
                                  sagemaker_session=sagemaker_session,
                                  endpoint=ENDPOINT)

    expected_response = json_format.Parse(
        json.dumps(CLASSIFICATION_RESPONSE), classification_pb2.ClassificationResponse()
    ).SerializeToString()

    mock_response(expected_response, sagemaker_session, PROTO_CONTENT_TYPE)

    result = predictor.predict(data)

    sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with(
        Accept=PROTO_CONTENT_TYPE,
        Body='1,2,3',
        ContentType=CSV_CONTENT_TYPE,
        EndpointName='myendpoint'
    )

    # python 2 and 3 protobuf serialization has different precision so I'm checking
    # the version here
    if sys.version_info < (3, 0):
        assert str(result) == """result {
  classifications {
    classes {
      label: "0"
      score: 0.00128903763834
    }
    classes {
      label: "1"
      score: 0.981432199478
    }
    classes {
      label: "2"
      score: 0.0172787327319
    }
  }
}
"""
    else:
        assert str(result) == """result {