def test_create_model_step(sagemaker_session):
    """A CreateModelStep wrapping a single-container Model serializes correctly.

    Verifies that dependencies passed to the constructor and dependencies added
    afterwards via ``add_depends_on`` are both listed under ``DependsOn``. Also
    checks the model arguments in the request and the step's ``ModelName``
    property reference expression.
    """
    single_model = Model(
        image_uri=IMAGE_URI,
        role=ROLE,
        sagemaker_session=sagemaker_session,
    )
    create_input = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    create_step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        model=single_model,
        inputs=create_input,
    )
    # A dependency added after construction must be appended after the
    # constructor-supplied one.
    create_step.add_depends_on(["SecondTestStep"])

    request = create_step.to_request()
    expected_request = {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "ExecutionRoleArn": "DummyRole",
            "PrimaryContainer": {
                "Environment": {},
                "Image": "fakeimage",
            },
        },
    }
    assert request == expected_request

    model_name_ref = create_step.properties.ModelName.expr
    assert model_name_ref == {"Get": "Steps.MyCreateModelStep.ModelName"}
def test_create_model_step_with_model_pipeline(tfo, time, sagemaker_session):
    """A CreateModelStep wrapping a two-model PipelineModel serializes correctly.

    The pipeline chains a framework model and a SparkML model. The request must
    list one entry per model under ``Containers`` and include the optional
    ``DisplayName``/``Description`` fields. It must also carry both the
    constructor-supplied and the later-added dependencies.
    """
    # NOTE(review): ``tfo`` and ``time`` are not used in this body; presumably
    # they are mocks injected by ``mock.patch`` decorators outside this view —
    # confirm before removing.
    first_model = DummyFrameworkModel(sagemaker_session)
    second_model = SparkMLModel(
        model_data="s3://bucket/model_2.tar.gz",
        role=ROLE,
        sagemaker_session=sagemaker_session,
        env={"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
    )
    pipeline = PipelineModel(
        models=[first_model, second_model],
        role=ROLE,
        sagemaker_session=sagemaker_session,
    )
    create_input = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    create_step = CreateModelStep(
        name="MyCreateModelStep",
        depends_on=["TestStep"],
        display_name="MyCreateModelStep",
        description="TestDescription",
        model=pipeline,
        inputs=create_input,
    )
    create_step.add_depends_on(["SecondTestStep"])

    framework_container = {
        "Environment": {
            "SAGEMAKER_PROGRAM": "dummy_script.py",
            "SAGEMAKER_SUBMIT_DIRECTORY": "s3://my-bucket/mi-1-2017-10-10-14-14-15/sourcedir.tar.gz",
            "SAGEMAKER_CONTAINER_LOG_LEVEL": "20",
            "SAGEMAKER_REGION": "us-west-2",
        },
        "Image": "mi-1",
        "ModelDataUrl": "s3://bucket/model_1.tar.gz",
    }
    sparkml_container = {
        "Environment": {"SAGEMAKER_DEFAULT_INVOCATIONS_ACCEPT": "text/csv"},
        "Image": "246618743249.dkr.ecr.us-west-2.amazonaws.com/sagemaker-sparkml-serving:2.4",
        "ModelDataUrl": "s3://bucket/model_2.tar.gz",
    }

    request = create_step.to_request()
    # Container order must follow the order of ``models`` in the PipelineModel.
    expected_request = {
        "Name": "MyCreateModelStep",
        "Type": "Model",
        "Description": "TestDescription",
        "DisplayName": "MyCreateModelStep",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "Containers": [framework_container, sparkml_container],
            "ExecutionRoleArn": "DummyRole",
        },
    }
    assert request == expected_request

    model_name_ref = create_step.properties.ModelName.expr
    assert model_name_ref == {"Get": "Steps.MyCreateModelStep.ModelName"}