def test_create_model_step(sagemaker_session):
    """Check how CreateModelStep serializes a single-container Model.

    The step gets one dependency through the constructor and a second one
    afterwards through ``add_depends_on``. Both must appear, in that order,
    under ``DependsOn``. The ``CreateModelInput`` fields (instance type,
    accelerator type) are not expected to appear in the serialized
    ``Arguments``.
    """
    single_model = Model(
        image_uri=IMAGE_URI,
        role=ROLE,
        sagemaker_session=sagemaker_session,
    )
    create_input = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    create_step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        model=single_model,
        inputs=create_input,
    )
    create_step.add_depends_on(["SecondTestStep"])

    # Expected payload: a single PrimaryContainer with no environment.
    expected_container = {"Environment": {}, "Image": "fakeimage"}
    expected_request = {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "ExecutionRoleArn": "DummyRole",
            "PrimaryContainer": expected_container,
        },
    }
    assert create_step.to_request() == expected_request

    # The step exposes its ModelName as a property reference expression.
    expected_model_name_expr = {"Get": "Steps.MyCreateModelStep.ModelName"}
    assert create_step.properties.ModelName.expr == expected_model_name_expr
# Example #2
# 0
def test_create_model_step_with_model_pipeline(tfo, time, sagemaker_session):
    """Check how CreateModelStep serializes a two-container PipelineModel.

    The pipeline chains a DummyFrameworkModel and a SparkMLModel. The
    request must list both under ``Containers`` in that order, and it must
    carry the step's display name and description.

    ``tfo`` and ``time`` are not read in this body. Presumably they are
    mocks injected by ``mock.patch`` decorators that are not visible here
    (TODO: confirm). The timestamped submit-directory path below suggests
    ``time`` is patched.
    """
    dummy_model = DummyFrameworkModel(sagemaker_session)
    spark_model = SparkMLModel(
        model_data="s3://bucket/model_2.tar.gz",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
    )
    pipeline_model = PipelineModel(
        models=[dummy_model, spark_model],
        role=ROLE,
        sagemaker_session=sagemaker_session,
    )
    create_input = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    create_step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        display_name="MyCreateModelStep",
        description="TestDescription",
        model=pipeline_model,
        inputs=create_input,
    )
    create_step.add_depends_on(["SecondTestStep"])

    # First container: the framework model and its script/region environment.
    framework_env = {
        "SAGEMAKER_PROGRAM": "dummy_script.py",
        "SAGEMAKER_SUBMIT_DIRECTORY": "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
        "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
        "SAGEMAKER_REGION": "us-west-2",
    }
    framework_container = {
        "Environment": framework_env,
        "Image": "mi-1",
        "ModelDataUrl": "s3://bucket/model_1.tar.gz",
    }

    # Second container: the SparkML serving image with the env passed above.
    sparkml_image = "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4"
    sparkml_container = {
        "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
        "Image": sparkml_image,
        "ModelDataUrl": "s3://bucket/model_2.tar.gz",
    }

    expected_request = {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "Description": "TestDescription",
        "DisplayName": "MyCreateModelStep",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "Containers": [framework_container, sparkml_container],
            "ExecutionRoleArn": "DummyRole",
        },
    }
    assert create_step.to_request() == expected_request

    expected_model_name_expr = {"Get": "Steps.MyCreateModelStep.ModelName"}
    assert create_step.properties.ModelName.expr == expected_model_name_expr