def __init__(self, predictorEndpointName):
    """Set up the detector against a named prediction endpoint.

    Always prints an instantiation banner first. When
    ``predictorEndpointName`` is falsy (``None``, empty string, ...), a
    reminder is printed and the method returns without assigning
    ``self.predictor`` or calling ``self.set_class_labels()``.

    Args:
        predictorEndpointName: Name handed to ``sagemaker.RealTimePredictor``
            as its first positional argument.
    """
    print('Instantiating ObjectDetector')

    # Guard clause: without an endpoint name there is nothing to connect to.
    # NOTE(review): the instance is left without a ``predictor`` attribute in
    # this case; callers only get the printed message, not an exception.
    if not predictorEndpointName:
        print('Must Supply an Endpoint Name')
        return

    self.predictor = sagemaker.RealTimePredictor(predictorEndpointName)
    # ``set_class_labels`` is defined elsewhere on the class; it runs only
    # after the predictor has been assigned.
    self.set_class_labels()
def test_predict_jsons(tfs_predictor):
    """Check prediction with newline-separated JSON ('application/jsons').

    Builds a ``RealTimePredictor`` from the endpoint and session of the
    ``tfs_predictor`` fixture, sends two JSON arrays separated by a newline
    with no client-side serializer, and asserts the JSON-deserialized
    response equals the expected predictions.
    """
    payload = '[1.0, 2.0, 5.0]\n[1.0, 2.0, 5.0]'
    # Both input rows are identical, so both expected prediction rows are too.
    expected_row = [3.5, 4.0, 5.5]
    expected = {'predictions': [expected_row, expected_row]}

    jsons_predictor = sagemaker.RealTimePredictor(
        tfs_predictor.endpoint,
        tfs_predictor.sagemaker_session,
        serializer=None,  # payload is already a string; send it as-is
        deserializer=sagemaker.predictor.json_deserializer,
        content_type='application/jsons',
        accept='application/jsons',
    )

    assert expected == jsons_predictor.predict(payload)
def predict_wrapper(endpoint, session):
    """Return a ``RealTimePredictor`` for *endpoint* that serializes as CSV.

    Args:
        endpoint: Passed through as the predictor's first positional argument.
        session: Passed through as the predictor's second positional argument.

    Returns:
        A ``sagemaker.RealTimePredictor`` constructed with
        ``sagemaker.predictor.csv_serializer`` as its ``serializer``.
    """
    predictor_cls = sagemaker.RealTimePredictor
    csv_serializer = sagemaker.predictor.csv_serializer
    return predictor_cls(endpoint, session, serializer=csv_serializer)