コード例 #1
0
def main(datacapture=False):
    """Deploy the latest approved model package to a new SageMaker endpoint.

    Configuration is read from environment variables (defaults in brackets):

    * ``AWS_DEFAULT_REGION`` ['eu-west-1'], ``AWS_PROFILE`` ['default'],
      ``AWS_ACCESS_KEY_ID`` / ``AWS_SECRET_ACCESS_KEY`` [None] - passed to
      ``get_sm_session``.
    * ``AWS_ROLE`` - execution role ARN; only when it is unset is
      ``sagemaker.get_execution_role()`` called as a fallback.
    * ``MODEL_PACKAGE_GROUP_NAME`` ['sts-sklearn-grp'] - group searched by
      ``get_approved_package``.
    * ``BASE_JOB_PREFIX`` ['sts'] - prefix for the endpoint name and S3 keys.

    The endpoint is named ``<prefix>-sklearn-<YYYYmmddHHMM>`` and served by an
    ``SKLearnModel`` (``model_loader.py`` entry point, framework 0.23-1) on a
    single ml.m5.xlarge instance. A summary (model package description,
    endpoint names and, optionally, the data-capture S3 URI) is written to
    ``deploymodel_out.json`` in the current working directory.

    Args:
        datacapture: only the literal ``True`` enables data capture (100% of
            requests and responses) to ``s3://<default bucket>/<prefix>/
            <endpoint>/datacapture``; the URI is also recorded under
            ``outputs['monitor']``.
    """
    # --- configuration from the environment (AWS specific) ---
    region = os.getenv('AWS_DEFAULT_REGION', 'eu-west-1')
    profile = os.getenv('AWS_PROFILE', 'default')
    access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
    secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
    _b3_session, sm_client, _sm_runtime, sm_session = get_sm_session(
        region=region,
        profile_name=profile,
        aws_access_key_id=access_key_id,
        aws_secret_access_key=secret_access_key,
    )
    # Resolve the role lazily: used as a getenv() default it was evaluated on
    # every call, even with AWS_ROLE set, and get_execution_role() can fail
    # when not running inside a SageMaker-managed environment.
    role_arn = os.getenv('AWS_ROLE')
    if role_arn is None:
        role_arn = sagemaker.get_execution_role()

    model_package_group_name = os.getenv(
        'MODEL_PACKAGE_GROUP_NAME', 'sts-sklearn-grp')
    base_job_prefix = os.getenv('BASE_JOB_PREFIX', 'sts')

    # --- derived names ---
    bucket = sm_session.default_bucket()
    endpoint_name = "{}-sklearn-{}".format(
        base_job_prefix,
        datetime.datetime.now().strftime("%Y%m%d%H%M"),
    )
    prefix = "{}/{}".format(base_job_prefix, endpoint_name)
    s3_capture_upload_path = "s3://{}/{}/datacapture".format(bucket, prefix)
    # Keep the original strict check: only the literal True enables capture.
    capture_enabled = datacapture is True

    # Everything collected here is dumped to deploymodel_out.json at the end.
    outputs = {}
    if capture_enabled:
        outputs['monitor'] = {
            's3_capture_upload_path': s3_capture_upload_path
        }

    # --- latest approved version in the model package group ---
    model_package_arn = get_approved_package(
        model_package_group_name, sm_client)
    _l.info(f"Latest approved model package: {model_package_arn}")
    model_info = sm_client.describe_model_package(
        ModelPackageName=model_package_arn)
    outputs['model_info'] = model_info
    # Direct indexing: a missing key now raises a KeyError naming the key
    # instead of "'NoneType' object is not subscriptable".
    model_uri = (
        model_info['InferenceSpecification']['Containers'][0]['ModelDataUrl']
    )
    _l.info(f"Model data uri: {model_uri}")

    sk_model = SKLearnModel(
        model_uri,          # S3 URI of the model.tar.gz
        role_arn,           # SageMaker execution role
        'model_loader.py',  # inference script that loads the model
        framework_version='0.23-1',
    )

    data_capture_config = None
    if capture_enabled:
        _l.info("Enabling data capture as requested")
        _l.info(f"s3_capture_upload_path: {s3_capture_upload_path}")
        data_capture_config = DataCaptureConfig(
            enable_capture=True,
            sampling_percentage=100,
            destination_s3_uri=s3_capture_upload_path,
            capture_options=["REQUEST", "RESPONSE"],
            sagemaker_session=sm_session,
        )

    # --- deploy the endpoint ---
    predictor = sk_model.deploy(
        instance_type="ml.m5.xlarge",
        initial_instance_count=1,
        serializer=CSVSerializer(),
        deserializer=CSVDeserializer(),
        data_capture_config=data_capture_config,
        endpoint_name=endpoint_name,
    )
    _l.info(f"Endpoint name: {predictor.endpoint_name}")

    # Read the endpoint config name back from the service instead of assuming
    # it matches the endpoint name.
    endpoint_config_name = sm_client.describe_endpoint(
        EndpointName=predictor.endpoint_name)['EndpointConfigName']
    outputs['endpoint'] = {
        'name': endpoint_name,
        'config_name': endpoint_config_name,
    }
    # model_info is the same dict already stored in outputs['model_info'].
    model_info['name'] = sk_model.name

    # --- persist outputs ---
    # json_default (defined elsewhere) presumably serializes values json can't
    # handle natively, e.g. datetimes in the describe_* response - confirm.
    with open('deploymodel_out.json', 'w', encoding='utf-8') as f:
        json.dump(outputs, f, default=json_default)
コード例 #2
0
# Deploy a model with SageMaker endpoint data capture enabled.
# NOTE(review): `model` must already exist (a sagemaker Model built earlier);
# it is not defined in this snippet.
from sagemaker.deserializers import CSVDeserializer
from sagemaker.model_monitor import DataCaptureConfig

# S3 location that receives the captured payloads.
s3_capture_path = "s3://monitoring/xgb-churn-data"

# Capture every invocation (100% sampling). REQUEST and RESPONSE are listed
# explicitly; this is the same as DataCaptureConfig's default.
data_capture_config = DataCaptureConfig(
    enable_capture=True,
    sampling_percentage=100,
    destination_s3_uri=s3_capture_path,
    capture_options=["REQUEST", "RESPONSE"],
)

# Create the endpoint; responses are parsed as CSV.
predictor = model.deploy(
    initial_instance_count=1,
    instance_type="ml.m4.large",
    endpoint_name="xgb-churn-monitor",
    data_capture_config=data_capture_config,
    deserializer=CSVDeserializer(),
)
コード例 #3
0
            EndpointConfigName=current_endpoint_config_name
        )
        production_variants = endpoint_config["ProductionVariants"]
        self._model_names = [d["ModelName"] for d in production_variants]
        return self._model_names

    @property
    def content_type(self):
        """MIME type of the payloads sent to the inference endpoint.

        Delegates to the ``CONTENT_TYPE`` attribute of ``self.serializer``.
        """
        active_serializer = self.serializer
        return active_serializer.CONTENT_TYPE

    @property
    def accept(self):
        """Content type(s) the inference endpoint is expected to return.

        Delegates to the ``ACCEPT`` attribute of ``self.deserializer``.
        """
        active_deserializer = self.deserializer
        return active_deserializer.ACCEPT

    @property
    def endpoint(self):
        """Deprecated alias of ``endpoint_name``; use ``endpoint_name`` instead.

        Calls ``renamed_warning`` on every access, then returns
        ``self.endpoint_name``.
        """
        renamed_warning("The endpoint attribute")
        current_name = self.endpoint_name
        return current_name


# Legacy module-level names kept as aliases of the current classes/instances.
# deprecated_serialize / deprecated_deserialize / deprecated_class are defined
# elsewhere; presumably they emit a deprecation warning when the alias is used
# (NOTE(review): confirm against their definitions). Each alias receives its
# own name as a string, presumably for that warning message.
csv_serializer = deprecated_serialize(CSVSerializer(), "csv_serializer")
json_serializer = deprecated_serialize(JSONSerializer(), "json_serializer")
npy_serializer = deprecated_serialize(NumpySerializer(), "npy_serializer")
csv_deserializer = deprecated_deserialize(CSVDeserializer(), "csv_deserializer")
json_deserializer = deprecated_deserialize(JSONDeserializer(), "json_deserializer")
numpy_deserializer = deprecated_deserialize(NumpyDeserializer(), "numpy_deserializer")
# Old class name RealTimePredictor now maps to Predictor.
RealTimePredictor = deprecated_class(Predictor, "RealTimePredictor")
コード例 #4
0
def csv_deserializer():
    """Build and return a new ``CSVDeserializer`` instance.

    A fresh object is constructed on every call.
    """
    deserializer = CSVDeserializer()
    return deserializer