def test_predictor_classify(sagemaker_session):
    """Verify that ``classify`` issues a JSON request tagged with the
    ``classify`` TFS method and hands back the mocked response payload."""
    predictor = TensorFlowPredictor("endpoint", sagemaker_session)

    # Prime the mocked runtime client with the serialized canned response.
    response_bytes = json.dumps(CLASSIFY_RESPONSE).encode("utf-8")
    mock_response(response_bytes, sagemaker_session)

    result = predictor.classify(CLASSIFY_INPUT)

    # The request body should be the JSON encoding of the classify input.
    expected_body = json.dumps(CLASSIFY_INPUT)
    assert_invoked_with_body_dict(
        sagemaker_session,
        EndpointName="endpoint",
        ContentType=JSON_CONTENT_TYPE,
        Accept=JSON_CONTENT_TYPE,
        CustomAttributes="tfs-method=classify",
        Body=expected_body,
    )

    assert CLASSIFY_RESPONSE == result


def test_predictor_classify_bad_content_type(sagemaker_session):
    """``classify`` must raise ``ValueError`` when the predictor is configured
    with a CSV serializer instead of a JSON one."""
    # Pass the serializer by keyword so the test does not silently depend on
    # the positional order of TensorFlowPredictor's constructor parameters.
    predictor = TensorFlowPredictor(
        "endpoint", sagemaker_session, serializer=CSVSerializer()
    )

    # NOTE(review): the exact error message is produced by TensorFlowPredictor
    # and is not visible here; consider tightening with pytest.raises(match=...).
    with pytest.raises(ValueError):
        predictor.classify(CLASSIFY_INPUT)