Example #1
def test_sagemakermodel_passes_correct_params_to_scala():

    model_image = "model-abc-123"
    model_path = S3DataPath("my-bucket", "model-abc-123")
    role_arn = "role-789"
    endpoint_instance_type = "c4.8xlarge"

    model = SageMakerModel(
        endpointInstanceType=endpoint_instance_type,
        endpointInitialInstanceCount=2,
        requestRowSerializer=ProtobufRequestRowSerializer(),
        responseRowDeserializer=KMeansProtobufResponseRowDeserializer(),
        modelImage=model_image,
        modelPath=model_path,
        modelEnvironmentVariables=None,
        modelExecutionRoleARN=role_arn,
        endpointCreationPolicy=EndpointCreationPolicy.DO_NOT_CREATE,
        sagemakerClient=SageMakerClients.create_sagemaker_client(),
        prependResultRows=False,
        namePolicy=None,
        uid="uid")

    assert model.modelImage == model_image
    assert model.modelPath.bucket == model_path.bucket
    assert model.modelExecutionRoleARN == role_arn
    assert model.endpointInstanceType == endpoint_instance_type
    assert model.existingEndpointName is None
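
A model constructed this way follows the Spark ML Transformer API. As a hedged usage sketch (not part of the test above; `df` is a hypothetical DataFrame with a Vector column named "features", and a live endpoint is assumed to be available):

# Sketch only: score a DataFrame against the model's SageMaker endpoint.
transformed = model.transform(df)
transformed.show()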
Example #2
def test_sagemakermodel_can_do_resource_cleanup():
    endpoint_name = "my-existing-endpoint-123"
    model = SageMakerModel(
        endpointInstanceType="x1.128xlarge",
        endpointInitialInstanceCount=2,
        requestRowSerializer=ProtobufRequestRowSerializer(),
        responseRowDeserializer=KMeansProtobufResponseRowDeserializer(),
        existingEndpointName=endpoint_name,
        modelImage="some_image",
        modelPath=S3DataPath("a", "b"),
        modelEnvironmentVariables=None,
        modelExecutionRoleARN="role",
        endpointCreationPolicy=EndpointCreationPolicy.DO_NOT_CREATE,
        sagemakerClient=SageMakerClients.create_sagemaker_client(),
        prependResultRows=False,
        namePolicy=None,
        uid="uid")

    sm = model.sagemakerClient
    assert sm is not None

    resource_cleanup = SageMakerResourceCleanup(sm)
    assert resource_cleanup is not None

    created_resources = model.getCreatedResources()
    assert created_resources is not None

    resource_cleanup.deleteResources(created_resources)
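
A common follow-up to the pattern above, sketched here rather than taken from the original test, is to guarantee cleanup even if the transformation fails (`df` is a hypothetical input DataFrame):

cleanup = SageMakerResourceCleanup(model.sagemakerClient)
try:
    # Sketch only: any endpoint, model, or endpoint config the wrapper creates
    # is recorded in getCreatedResources() and can be deleted afterwards.
    model.transform(df).show()
finally:
    cleanup.deleteResources(model.getCreatedResources())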
Example #3
def test_linearLearnerBinaryClassifier_passes_correct_params_to_scala():

    training_instance_type = "c4.8xlarge"
    training_instance_count = 3
    endpoint_instance_type = "c4.8xlarge"
    endpoint_initial_instance_count = 3

    training_bucket = "random-bucket"
    input_prefix = "linear-learner-binary-classifier-training"
    output_prefix = "linear-learner-binary-classifier-out"
    integTestingRole = "arn:aws:iam::123456789:role/SageMakerRole"

    estimator = LinearLearnerBinaryClassifier(
        trainingInstanceType=training_instance_type,
        trainingInstanceCount=training_instance_count,
        endpointInstanceType=endpoint_instance_type,
        endpointInitialInstanceCount=endpoint_initial_instance_count,
        sagemakerRole=IAMRole(integTestingRole),
        requestRowSerializer=ProtobufRequestRowSerializer(),
        responseRowDeserializer=
        LinearLearnerBinaryClassifierProtobufResponseRowDeserializer(),
        trainingInstanceVolumeSizeInGB=2048,
        trainingInputS3DataPath=S3DataPath(training_bucket, input_prefix),
        trainingOutputS3DataPath=S3DataPath(training_bucket, output_prefix),
        trainingMaxRuntimeInSeconds=1,
        endpointCreationPolicy=EndpointCreationPolicy.CREATE_ON_TRANSFORM,
        sagemakerClient=SageMakerClients.create_sagemaker_client(),
        s3Client=SageMakerClients.create_s3_default_client(),
        stsClient=SageMakerClients.create_sts_default_client(),
        modelPrependInputRowsToTransformationRows=True,
        namePolicyFactory=RandomNamePolicyFactory(),
        uid="sagemaker")

    assert estimator.trainingInputS3DataPath.bucket == training_bucket
    assert estimator.trainingInputS3DataPath.objectPath == input_prefix
    assert estimator.trainingInstanceCount == training_instance_count
    assert estimator.trainingInstanceType == training_instance_type
    assert estimator.endpointInstanceType == endpoint_instance_type
    assert estimator.endpointInitialInstanceCount == endpoint_initial_instance_count
    assert estimator.trainingInstanceVolumeSizeInGB == 2048
    assert estimator.trainingMaxRuntimeInSeconds == 1
    assert estimator.trainingKmsKeyId is None
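
The estimator follows the Spark ML Estimator API, so training and inference are the usual fit/transform calls. A hedged sketch (not from the test above; `trainingData` and `testData` are hypothetical DataFrames with "label" and "features" columns):

# Sketch only: fit launches a SageMaker training job and returns a SageMakerModel;
# transform scores rows against the resulting endpoint.
model = estimator.fit(trainingData)
predictions = model.transform(testData)
predictions.show()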
Example #4
def test_sagemakermodel_can_be_created_from_java_obj():
    endpoint_name = "my-existing-endpoint-123"
    model = SageMakerModel(
        endpointInstanceType="x1.128xlarge",
        endpointInitialInstanceCount=2,
        requestRowSerializer=ProtobufRequestRowSerializer(),
        responseRowDeserializer=KMeansProtobufResponseRowDeserializer(),
        existingEndpointName=endpoint_name,
        modelImage="some_image",
        modelPath=S3DataPath("a", "b"),
        modelEnvironmentVariables=None,
        modelExecutionRoleARN="role",
        endpointCreationPolicy=EndpointCreationPolicy.DO_NOT_CREATE,
        sagemakerClient=SageMakerClients.create_sagemaker_client(),
        prependResultRows=False,
        namePolicy=None,
        uid="uid")

    new_model = SageMakerModel._from_java(model._to_java())
    assert new_model.uid == model.uid
Example #5
    endpointInstanceType=None,  # Required
    endpointInitialInstanceCount=None,  # Required
    requestRowSerializer=ProtobufRequestRowSerializer(
        featuresColumnName="features"),  # Optional: already default value
    responseRowDeserializer=
    KMeansProtobufResponseRowDeserializer(  # Optional: already default values
        distance_to_cluster_column_name="distance_to_cluster",
        closest_cluster_column_name="closest_cluster"))

transformedData2 = attachedModel.transform(testData)
transformedData2.show()
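# (Illustrative addition, not from the original notebook) The deserializer
# configured above adds these output columns to the transformed DataFrame:
transformedData2.select("closest_cluster", "distance_to_cluster").show(5)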

# Create a model and endpoint from the model data
from sagemaker_pyspark import S3DataPath

MODEL_S3_PATH = S3DataPath(initialModel.modelPath.bucket,
                           initialModel.modelPath.objectPath)
MODEL_ROLE_ARN = initialModel.modelExecutionRoleARN
MODEL_IMAGE_PATH = initialModel.modelImage

print(MODEL_S3_PATH.bucket + "/" + MODEL_S3_PATH.objectPath)
print(MODEL_ROLE_ARN)
print(MODEL_IMAGE_PATH)

from sagemaker_pyspark import RandomNamePolicy

retrievedModel = SageMakerModel(
    modelPath=MODEL_S3_PATH,
    modelExecutionRoleARN=MODEL_ROLE_ARN,
    modelImage=MODEL_IMAGE_PATH,
    endpointInstanceType="ml.t2.medium",
    endpointInitialInstanceCount=1,