コード例 #1
0
def test_predictor_type(sagemaker_session):
    """Deploying a SparkMLModel must return a SparkMLPredictor instance."""
    model = SparkMLModel(
        sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE
    )

    deployed = model.deploy(1, TRAIN_INSTANCE_TYPE)

    assert isinstance(deployed, SparkMLPredictor)
コード例 #2
0
def test_predictor_custom_serialization(sagemaker_session):
    """A serializer passed to deploy() must be attached to the predictor as-is."""
    model = SparkMLModel(
        sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE
    )
    serializer_stub = Mock()

    deployed = model.deploy(1, TRAIN_INSTANCE_TYPE, serializer=serializer_stub)

    assert isinstance(deployed, SparkMLPredictor)
    # Identity check: deploy() must not wrap or copy the custom serializer.
    assert deployed.serializer is serializer_stub
コード例 #3
0
def test_spark_ml_predict_invocation_with_target_variant(sagemaker_session):
    """Integration test: predict() honors ``target_variant`` and fails with a
    ValidationError naming the bad variant; resources are cleaned up afterwards.

    Args:
        sagemaker_session: fixture providing a configured SageMaker session.
    """
    spark_ml_model_endpoint_name = unique_name_from_base(
        "integ-test-target-variant-sparkml")

    model_data = sagemaker_session.upload_data(
        path=SPARK_ML_MODEL_LOCAL_PATH,
        key_prefix="integ-test-data/sparkml/model")

    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
            spark_ml_model_endpoint_name, sagemaker_session):
        spark_ml_model = SparkMLModel(
            model_data=model_data,
            role=ROLE,
            sagemaker_session=sagemaker_session,
            env={"SAGEMAKER_SPARKML_SCHEMA": SPARK_ML_MODEL_SCHEMA},
        )

        predictor = spark_ml_model.deploy(
            DEFAULT_INSTANCE_COUNT,
            DEFAULT_INSTANCE_TYPE,
            endpoint_name=spark_ml_model_endpoint_name,
        )

        # Validate that no exception is raised when the target_variant is specified.
        predictor.predict(SPARK_ML_TEST_DATA,
                          target_variant=SPARK_ML_DEFAULT_VARIANT_NAME)

        with pytest.raises(Exception) as exception_info:
            predictor.predict(SPARK_ML_TEST_DATA,
                              target_variant=SPARK_ML_WRONG_VARIANT_NAME)

        assert "ValidationError" in str(exception_info.value)
        assert SPARK_ML_WRONG_VARIANT_NAME in str(exception_info.value)

        # cleanup resources
        spark_ml_model.delete_model()
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=spark_ml_model_endpoint_name)

    # Validate resource cleanup.
    # BUG FIX: previously both describe_* calls and all message assertions
    # shared one pytest.raises block — everything after the first raising
    # call was unreachable (the raise exits the block), so the assertions
    # and the endpoint-config check never ran. Each call now has its own
    # raises-context, with its assertion placed *outside* the block.
    with pytest.raises(Exception) as model_exception:
        sagemaker_session.sagemaker_client.describe_model(
            ModelName=spark_ml_model.name)
    assert "Could not find model" in str(model_exception.value)

    with pytest.raises(Exception) as config_exception:
        sagemaker_session.sagemaker_client.describe_endpoint_config(
            name=spark_ml_model_endpoint_name)
    assert "Could not find endpoint" in str(config_exception.value)
コード例 #4
0
def test_prepare_container_def(tfo, time, sagemaker_session):
    """A pipeline of framework + SparkML models must produce two container
    definitions in order, each carrying its own image, env, and model data."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(
        model_data=MODEL_DATA_2,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={'SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT': 'text/csv'})
    pipeline = PipelineModel(models=[framework_model, sparkml_model],
                             role=ROLE,
                             sagemaker_session=sagemaker_session)

    # First container: the dummy framework model.
    expected_framework_container = {
        'Environment': {
            'SAGEMAKER_PROGRAM': 'blah.py',
            'SAGEMAKER_SUBMIT_DIRECTORY':
            's3://mybucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz',
            'SAGEMAKER_CONTAINER_LOG_LEVEL': '20',
            'SAGEMAKER_REGION': 'us-west-2',
            'SAGEMAKER_ENABLE_CLOUDWATCH_METRICS': 'false'
        },
        'Image': 'mi-1',
        'ModelDataUrl': 's3://bucket/model_1.tar.gz',
    }
    # Second container: the SparkML serving image with the custom accept env.
    expected_sparkml_container = {
        'Environment': {
            'SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT': 'text/csv'
        },
        'Image': ('246618743249.dkr.ecr.us-west-2.amazonaws.com'
                  '/sagemaker-sparkml-serving:2.2'),
        'ModelDataUrl': 's3://bucket/model_2.tar.gz',
    }

    assert pipeline.pipeline_container_def(INSTANCE_TYPE) == [
        expected_framework_container, expected_sparkml_container
    ]
コード例 #5
0
def test_prepare_container_def(tfo, time, sagemaker_session):
    """Pipeline container definitions must list the framework container first,
    then the SparkML serving container, each with its expected settings."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(
        model_data=MODEL_DATA_2,
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
    )
    pipeline = PipelineModel(models=[framework_model, sparkml_model],
                             role=ROLE,
                             sagemaker_session=sagemaker_session)

    containers = pipeline.pipeline_container_def(INSTANCE_TYPE)

    expected = [
        {
            "Environment": {
                "SAGEMAKER_PROGRAM": "blah.py",
                "SAGEMAKER_SUBMIT_DIRECTORY":
                "s3://mybucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                "SAGEMAKER_REGION": "us-west-2",
                "SAGEMAKER_ENABLE_CLOUDWATCH_METRICS": "false",
            },
            "Image": "mi-1",
            "ModelDataUrl": "s3://bucket/model_1.tar.gz",
        },
        {
            "Environment": {
                "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"
            },
            "Image": ("246618743249.dkr.ecr.us-west-2.amazonaws.com"
                      "/sagemaker-sparkml-serving:2.2"),
            "ModelDataUrl": "s3://bucket/model_2.tar.gz",
        },
    ]
    assert containers == expected
コード例 #6
0
def test_deploy_update_endpoint(tfo, time, sagemaker_session):
    """deploy(update_endpoint=True) must create a new endpoint config and
    update the existing endpoint rather than creating a new one."""
    endpoint_name = "endpoint-name"
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(model_data=MODEL_DATA_2,
                                 role=ROLE,
                                 sagemaker_session=sagemaker_session)
    pipeline = PipelineModel(models=[framework_model, sparkml_model],
                             role=ROLE,
                             sagemaker_session=sagemaker_session)

    pipeline.deploy(
        instance_type=INSTANCE_TYPE,
        initial_instance_count=1,
        endpoint_name=endpoint_name,
        update_endpoint=True,
    )

    sagemaker_session.create_endpoint_config.assert_called_with(
        name=pipeline.name,
        model_name=pipeline.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
        tags=None,
    )
    # Re-invoke the mocked call to harvest its (mocked) return value, which is
    # what deploy() passed on to update_endpoint.
    expected_config_name = sagemaker_session.create_endpoint_config(
        name=pipeline.name,
        model_name=pipeline.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE,
    )
    sagemaker_session.update_endpoint.assert_called_with(
        endpoint_name, expected_config_name)
    sagemaker_session.create_endpoint.assert_not_called()
コード例 #7
0
def test_deploy_tags(tfo, time, sagemaker_session):
    """Tags handed to deploy() must be forwarded verbatim to
    endpoint_from_production_variants."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(model_data=MODEL_DATA_2,
                                 role=ROLE,
                                 sagemaker_session=sagemaker_session)
    pipeline = PipelineModel(models=[framework_model, sparkml_model],
                             role=ROLE,
                             sagemaker_session=sagemaker_session)
    tags = [{"ModelName": "TestModel"}]

    pipeline.deploy(instance_type=INSTANCE_TYPE,
                    initial_instance_count=1,
                    tags=tags)

    expected_variant = {
        "InitialVariantWeight": 1,
        "ModelName": "mi-1-2017-10-10-14-14-15",
        "InstanceType": INSTANCE_TYPE,
        "InitialInstanceCount": 1,
        "VariantName": "AllTraffic",
    }
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        name="mi-1-2017-10-10-14-14-15",
        production_variants=[expected_variant],
        tags=tags,
        wait=True,
        kms_key=None,
        data_capture_config_dict=None,
    )
コード例 #8
0
def test_sparkml_model(sagemaker_session):
    """SparkMLModel must default to the region's sparkml-serving 2.4 image."""
    model = SparkMLModel(
        sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE
    )

    expected_image = image_uris.retrieve("sparkml-serving", REGION, version="2.4")
    assert model.image_uri == expected_image
コード例 #9
0
def test_create_model_step_with_model_pipeline(tfo, time, sagemaker_session):
    """CreateModelStep wrapping a PipelineModel must serialize both container
    definitions, merge depends_on lists, and expose ModelName as a property."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(
        model_data="s3://bucket/model_2.tar.gz",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
    )
    pipeline = PipelineModel(models=[framework_model, sparkml_model],
                             role=ROLE,
                             sagemaker_session=sagemaker_session)
    step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        display_name="MyCreateModelStep",
        description="TestDescription",
        model=pipeline,
        inputs=CreateModelInput(
            instance_type="c4.4xlarge",
            accelerator_type="ml.eia1.medium",
        ),
    )
    # Dependencies added after construction are appended, not replaced.
    step.add_depends_on(["SecondTestStep"])

    expected_containers = [
        {
            "Environment": {
                "SAGEMAKER_PROGRAM": "dummy_script.py",
                "SAGEMAKER_SUBMIT_DIRECTORY":
                "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                "SAGEMAKER_REGION": "us-west-2",
            },
            "Image": "mi-1",
            "ModelDataUrl": "s3://bucket/model_1.tar.gz",
        },
        {
            "Environment": {
                "SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"
            },
            "Image":
            "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4",
            "ModelDataUrl": "s3://bucket/model_2.tar.gz",
        },
    ]
    assert step.to_request() == {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "Description": "TestDescription",
        "DisplayName": "MyCreateModelStep",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "Containers": expected_containers,
            "ExecutionRoleArn": "DummyRole",
        },
    }
    assert step.properties.ModelName.expr == {
        "Get": "Steps.MyCreateModelStep.ModelName"
    }
コード例 #10
0
def test_transformer(tfo, time, sagemaker_session):
    """PipelineModel.transformer() must forward every configuration value
    verbatim onto the resulting Transformer."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(model_data=MODEL_DATA_2,
                                 role=ROLE,
                                 sagemaker_session=sagemaker_session)
    model_name = "ModelName"
    pipeline = PipelineModel(
        models=[framework_model, sparkml_model],
        role=ROLE,
        sagemaker_session=sagemaker_session,
        name=model_name,
    )

    # Every keyword passed here should come back unchanged as a transformer
    # attribute of the same name.
    settings = dict(
        instance_type=INSTANCE_TYPE,
        instance_count=55,
        strategy="MultiRecord",
        assemble_with="Line",
        output_path="s3://output/path",
        output_kms_key="output:kms:key",
        accept="application/jsonlines",
        env={"my_key": "my_value"},
        max_concurrent_transforms=20,
        max_payload=5,
        tags=[{"my_tag": "my_value"}],
        volume_kms_key="volume:kms:key",
    )
    transformer = pipeline.transformer(**settings)

    for attribute, expected in settings.items():
        assert getattr(transformer, attribute) == expected
    assert transformer.model_name == model_name
コード例 #11
0
def test_deploy(tfo, time, sagemaker_session):
    """A plain deploy() must create an endpoint from a single AllTraffic
    production variant with default (None) tags."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(model_data=MODEL_DATA_2,
                                 role=ROLE,
                                 sagemaker_session=sagemaker_session)
    pipeline = PipelineModel(models=[framework_model, sparkml_model],
                             role=ROLE,
                             sagemaker_session=sagemaker_session)

    pipeline.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)

    expected_variant = {
        'InitialVariantWeight': 1,
        'ModelName': 'mi-1-2017-10-10-14-14-15',
        'InstanceType': INSTANCE_TYPE,
        'InitialInstanceCount': 1,
        'VariantName': 'AllTraffic'
    }
    sagemaker_session.endpoint_from_production_variants.assert_called_with(
        'mi-1-2017-10-10-14-14-15', [expected_variant], None)
コード例 #12
0
def test_network_isolation(tfo, time, sagemaker_session):
    """enable_network_isolation=True on the PipelineModel must be propagated
    to create_model along with both container definitions."""
    framework_model = DummyFrameworkModel(sagemaker_session)
    sparkml_model = SparkMLModel(model_data=MODEL_DATA_2,
                                 role=ROLE,
                                 sagemaker_session=sagemaker_session)
    pipeline = PipelineModel(
        models=[framework_model, sparkml_model],
        role=ROLE,
        sagemaker_session=sagemaker_session,
        enable_network_isolation=True,
    )

    pipeline.deploy(instance_type=INSTANCE_TYPE, initial_instance_count=1)

    expected_containers = [
        {
            "Image": "mi-1",
            "Environment": {
                "SAGEMAKER_PROGRAM": "blah.py",
                "SAGEMAKER_SUBMIT_DIRECTORY":
                "s3://mybucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
                "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
                "SAGEMAKER_REGION": "us-west-2",
            },
            "ModelDataUrl": "s3://bucket/model_1.tar.gz",
        },
        {
            "Image":
            "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.2",
            "Environment": {},
            "ModelDataUrl": "s3://bucket/model_2.tar.gz",
        },
    ]
    sagemaker_session.create_model.assert_called_with(
        pipeline.name,
        ROLE,
        expected_containers,
        vpc_config=None,
        enable_network_isolation=True,
    )
コード例 #13
0
def test_sparkml_model(sagemaker_session):
    """The model image must be the regional sparkml-serving:2.2 repository URI."""
    model = SparkMLModel(
        sagemaker_session=sagemaker_session, model_data=MODEL_DATA, role=ROLE
    )
    expected = registry(REGION, "sparkml-serving") + "/sagemaker-sparkml-serving:2.2"
    assert model.image == expected