def test_classification_request_pb(sagemaker_session): request = classification_pb2.ClassificationRequest() request.model_spec.name = "generic_model" request.model_spec.signature_name = DEFAULT_SERVING_SIGNATURE_DEF_KEY example = request.input.example_list.examples.add() example.features.feature[PREDICT_INPUTS].float_list.value.extend([6.4, 3.2, 4.5, 1.5]) predictor = RealTimePredictor(sagemaker_session=sagemaker_session, endpoint=ENDPOINT, deserializer=tf_deserializer, serializer=tf_serializer) expected_response = classification_pb2.ClassificationResponse() classes = expected_response.result.classifications.add().classes class_0 = classes.add() class_0.label = "0" class_0.score = 0.00128903763834 class_1 = classes.add() class_1.label = "1" class_1.score = 0.981432199478 class_2 = classes.add() class_2.label = "2" class_2.score = 0.0172787327319 mock_response(expected_response.SerializeToString(), sagemaker_session, PROTO_CONTENT_TYPE) result = predictor.predict(request) sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with( Accept=PROTO_CONTENT_TYPE, Body=request.SerializeToString(), ContentType=PROTO_CONTENT_TYPE, EndpointName='myendpoint' ) # python 2 and 3 protobuf serialization has different precision so I'm checking # the version here if sys.version_info < (3, 0): assert str(result) == """result { classifications { classes { label: "0" score: 0.00128903763834 } classes { label: "1" score: 0.981432199478 } classes { label: "2" score: 0.0172787327319 } } } """ else: assert str(result) == """result {
def test_classification_request_csv(sagemaker_session): data = [1, 2, 3] predictor = RealTimePredictor(serializer=tf_csv_serializer, deserializer=tf_deserializer, sagemaker_session=sagemaker_session, endpoint=ENDPOINT) expected_response = json_format.Parse( json.dumps(CLASSIFICATION_RESPONSE), classification_pb2.ClassificationResponse() ).SerializeToString() mock_response(expected_response, sagemaker_session, PROTO_CONTENT_TYPE) result = predictor.predict(data) sagemaker_session.sagemaker_runtime_client.invoke_endpoint.assert_called_once_with( Accept=PROTO_CONTENT_TYPE, Body='1,2,3', ContentType=CSV_CONTENT_TYPE, EndpointName='myendpoint' ) # python 2 and 3 protobuf serialization has different precision so I'm checking # the version here if sys.version_info < (3, 0): assert str(result) == """result { classifications { classes { label: "0" score: 0.00128903763834 } classes { label: "1" score: 0.981432199478 } classes { label: "2" score: 0.0172787327319 } } } """ else: assert str(result) == """result {