def test_delete_endpoint_only():
    """Deleting with delete_endpoint_config=False must not delete the endpoint config."""
    session = empty_sagemaker_session()

    RealTimePredictor(ENDPOINT, sagemaker_session=session).delete_endpoint(
        delete_endpoint_config=False)

    session.delete_endpoint.assert_called_with(ENDPOINT)
    session.delete_endpoint_config.assert_not_called()


def test_delete_endpoint_with_config():
    """Default delete_endpoint() also deletes the config that DescribeEndpoint reports."""
    config_name = 'endpoint-config'
    session = empty_sagemaker_session()
    # Stub DescribeEndpoint before the predictor is built so it reports our
    # config name (kept in the original order: stub first, then construct).
    describe_response = {'EndpointConfigName': config_name}
    session.sagemaker_client.describe_endpoint = Mock(return_value=describe_response)

    predictor = RealTimePredictor(ENDPOINT, sagemaker_session=session)
    predictor.delete_endpoint()

    session.delete_endpoint.assert_called_with(ENDPOINT)
    session.delete_endpoint_config.assert_called_with(config_name)
예제 #3
0
class Predictor(object):
    """JSON client for a hosted SageMaker endpoint.

    Every request is a JSON object with ``request_type`` and ``observation``
    keys, sent through a ``RealTimePredictor`` configured with the SDK's JSON
    serializer and deserializer.
    """

    def __init__(self, endpoint_name, sagemaker_session=None):
        """
        Args:
            endpoint_name (str): name of the SageMaker endpoint
            sagemaker_session (sagemaker.session.Session): Manage interactions
                with the Amazon SageMaker APIs and any other AWS services needed.
        """
        self.endpoint_name = endpoint_name
        self._realtime_predictor = RealTimePredictor(
            endpoint_name,
            serializer=sagemaker.predictor.json_serializer,
            deserializer=sagemaker.predictor.json_deserializer,
            sagemaker_session=sagemaker_session)

    def _request(self, request_type, observation=None):
        """Send one request to the endpoint and return the decoded response.

        Args:
            request_type (str): value of the payload's ``request_type`` key,
                e.g. "observation" or "model_id".
            observation: value of the payload's ``observation`` key.

        Returns:
            The deserialized endpoint response; callers index it like a dict.
        """
        payload = {'request_type': request_type, 'observation': observation}
        return self._realtime_predictor.predict(payload)

    def get_action(self, obs=None):
        """Get a prediction from the endpoint for one observation.

        Args:
            obs (list/str): observation of the environment

        Returns:
            tuple: ``(action, event_id, model_id, action_prob, sample_prob)`` where
                action: action to take from the prediction
                event_id: event id of the current prediction
                model_id: model id of the hosted model
                action_prob: action probability distribution
                sample_prob: sample probability distribution used for data split

        Raises:
            KeyError: if the response lacks any of the keys above.
        """
        response = self._request("observation", obs)
        action = response['action']
        action_prob = response['action_prob']
        event_id = response['event_id']
        model_id = response['model_id']
        sample_prob = response['sample_prob']
        return action, event_id, model_id, action_prob, sample_prob

    def get_hosted_model_id(self):
        """Return the id of the model hosted behind the endpoint.

        Returns:
            str: model id of the model being hosted
        """
        return self._request("model_id")['model_id']

    def delete_endpoint(self):
        """Delete the SageMaker endpoint via the wrapped RealTimePredictor."""
        logger.warning("Deleting hosting endpoint '%s'...", self.endpoint_name)
        self._realtime_predictor.delete_endpoint()
예제 #4
0
from sagemaker.amazon.amazon_estimator import get_image_uri
from sagemaker.model import Model
from sagemaker.predictor import RealTimePredictor, csv_serializer, json_deserializer

# Deploy a trained linear-learner model to a real-time endpoint, score one
# CSV-encoded transaction against it, then delete the endpoint and model.
# Cleanup runs in ``finally`` so a failed prediction does not leave a
# billable endpoint running.

endpoint_name = 'creditcardfraudlogistic'

Model(
    model_data=
    's3://creditcardfraud123/logistic/output/linear-learner-191112-2119-002-ac3cc459/output/model.tar.gz',
    image=get_image_uri(region_name='us-east-1',
                        repo_name='linear-learner',
                        repo_version='latest'),
    role='AmazonSageMaker-ExecutionRole-20191005T164168').deploy(
        initial_instance_count=1,
        instance_type='ml.t2.2xlarge',
        endpoint_name=endpoint_name)

predictor = RealTimePredictor(endpoint_name)
predictor.content_type = 'text/csv'
predictor.serializer = csv_serializer
predictor.deserializer = json_deserializer

data = '83916.0,-0.46612620502545604,1.05888696127596,1.6867741713450801,-0.10791713399150099,-0.0534658672545062,-0.67078459643593,0.657296448523877,0.0267747128155009,-0.777065639315537,-0.16451379457928,1.6033857689344901,1.08437897734507,0.621801289885425,0.209210718774203,0.054395914001364995,0.30196805090530604,-0.610384355760504,-0.0111685840197793,0.22161067607904003,0.14522945875429,-0.155193422608588,-0.386047830532794,-0.019162727901044996,0.53588061095157,-0.22766218008636102,0.0387309462886897,0.266651773221212,0.114305983146032,2.58'

print(' ')
try:
    response = predictor.predict(data)
    # The response is read as {'predictions': [{'predicted_label': ...}, ...]};
    # a label of 0 is treated as "not fraud".
    if response['predictions'][0]['predicted_label'] == 0:
        print('Not a Fraudulent Transaction')
    else:
        print('Fraudulent Transaction')
finally:
    # Always tear down; the nested try ensures the model is still deleted
    # even if deleting the endpoint raises.
    try:
        predictor.delete_endpoint()
    finally:
        predictor.delete_model()