def test_predictor_classify(sagemaker_session):
    """classify() sends a JSON request tagged with the classify method.

    The mocked endpoint returns ``CLASSIFY_RESPONSE`` as UTF-8 JSON. The test
    checks two things: the invocation carries the JSON content type and accept
    headers plus the ``tfs-method=classify`` custom attribute, and the decoded
    response is returned to the caller unchanged.
    """
    tf_predictor = TensorFlowPredictor("endpoint", sagemaker_session)

    # Prime the mocked runtime client with the canned classify response.
    canned_payload = json.dumps(CLASSIFY_RESPONSE).encode("utf-8")
    mock_response(canned_payload, sagemaker_session)

    classified = tf_predictor.classify(CLASSIFY_INPUT)

    assert_invoked_with_body_dict(
        sagemaker_session,
        EndpointName="endpoint",
        ContentType=JSON_CONTENT_TYPE,
        Accept=JSON_CONTENT_TYPE,
        CustomAttributes="tfs-method=classify",
        Body=json.dumps(CLASSIFY_INPUT),
    )
    assert CLASSIFY_RESPONSE == classified
def test_predictor_classify_bad_content_type(sagemaker_session):
    """classify() raises ValueError when the predictor uses a CSV serializer."""
    csv_predictor = TensorFlowPredictor("endpoint", sagemaker_session, CSVSerializer())

    # The predictor is built with a non-JSON serializer, so classify() must reject the call.
    with pytest.raises(ValueError):
        csv_predictor.classify(CLASSIFY_INPUT)