def test_estimator_transformer(estimator):
    """Check the request dicts emitted by a basic ``EstimatorTransformer``.

    Exactly two steps are expected: a ``Model`` step that is compared in
    full, and a ``Transform`` step whose ``ModelName`` argument must be a
    ``Properties`` instance. A step of any other type fails the test.
    """
    collection = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        model_inputs=CreateModelInput(
            instance_type="c4.4xlarge",
            accelerator_type="ml.eia1.medium",
        ),
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=TransformInput(data=f"s3://{BUCKET}/transform_manifest"),
    )
    steps = collection.request_dicts()
    assert len(steps) == 2

    for step in steps:
        step_type = step["Type"]

        if step_type == "Model":
            expected_model_step = {
                "Name": "EstimatorTransformerStepCreateModelStep",
                "Type": "Model",
                "Arguments": {
                    "ExecutionRoleArn": "DummyRole",
                    "PrimaryContainer": {
                        "Environment": {},
                        "Image": "fakeimage",
                        "ModelDataUrl": "s3://my-bucket/model.tar.gz",
                    },
                },
            }
            assert step == expected_model_step
            continue

        if step_type != "Transform":
            raise Exception("A step exists in the collection of an invalid type.")

        assert step["Name"] == "EstimatorTransformerStepTransformStep"
        transform_args = step["Arguments"]
        assert isinstance(transform_args["ModelName"], Properties)
        # Drop ModelName so the remaining arguments can be compared to a literal.
        transform_args.pop("ModelName")
        expected_transform_args = {
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": f"s3://{BUCKET}/transform_manifest",
                    }
                }
            },
            "TransformOutput": {"S3OutputPath": None},
            "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
        }
        assert transform_args == expected_transform_args
# Example #2
# 0
def test_transform_step(sagemaker_session):
    """Check the request built by a fully configured ``TransformStep``.

    The step is created with a dependency, display name, description and
    cache config; a second dependency is appended via ``add_depends_on``.
    The serialized request and the ``TransformJobName`` property expression
    are compared against literals.
    """
    step = TransformStep(
        name="MyTransformStep",
        depends_on=["TestStep"],
        transformer=Transformer(
            model_name=MODEL_NAME,
            instance_count=1,
            instance_type="c4.4xlarge",
            sagemaker_session=sagemaker_session,
        ),
        display_name="TransformStep",
        description="TestDescription",
        inputs=TransformInput(data=f"s3://{BUCKET}/transform_manifest"),
        cache_config=CacheConfig(enable_caching=True, expire_after="PT1H"),
    )
    step.add_depends_on(["SecondTestStep"])

    expected_request = {
        "Name": "MyTransformStep",
        "Type": "Transform",
        "Description": "TestDescription",
        "DisplayName": "TransformStep",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "ModelName": "gisele",
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://my-bucket/transform_manifest",
                    }
                }
            },
            "TransformOutput": {"S3OutputPath": None},
            "TransformResources": {
                "InstanceCount": 1,
                "InstanceType": "c4.4xlarge",
            },
        },
        "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
    }
    assert step.to_request() == expected_request

    expected_job_name_expr = {"Get": "Steps.MyTransformStep.TransformJobName"}
    assert step.properties.TransformJobName.expr == expected_job_name_expr
def test_transform_step(sagemaker_session):
    """Check the request built by a minimally configured ``TransformStep``.

    Only a name, a transformer and inputs are supplied; the serialized
    request and the ``TransformJobName`` property expression are compared
    against literals.
    """
    step = TransformStep(
        name="MyTransformStep",
        transformer=Transformer(
            model_name=MODEL_NAME,
            instance_count=1,
            instance_type="c4.4xlarge",
            sagemaker_session=sagemaker_session,
        ),
        inputs=TransformInput(data=f"s3://{BUCKET}/transform_manifest"),
    )

    expected_request = {
        "Name": "MyTransformStep",
        "Type": "Transform",
        "Arguments": {
            "ModelName": "gisele",
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://my-bucket/transform_manifest",
                    }
                }
            },
            "TransformOutput": {"S3OutputPath": None},
            "TransformResources": {
                "InstanceCount": 1,
                "InstanceType": "c4.4xlarge",
            },
        },
    }
    assert step.to_request() == expected_request

    expected_job_name_expr = {"Get": "Steps.MyTransformStep.TransformJobName"}
    assert step.properties.TransformJobName.expr == expected_job_name_expr
def test_estimator_transformer(estimator):
    """Check ``EstimatorTransformer`` output with dependencies and retry policies.

    Exactly two steps are expected. The ``Model`` step must carry
    ``DependsOn`` and the retry policy; the ``Transform`` step must carry the
    retry policy, must not carry ``DependsOn``, and its ``ModelName``
    argument must be a ``Properties`` instance. A step of any other type
    fails the test.
    """
    create_model_inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    service_fault_retry_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        max_attempts=10,
    )
    collection = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        model_inputs=create_model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=TransformInput(data=f"s3://{BUCKET}/transform_manifest"),
        depends_on=["TestStep"],
        model_step_retry_policies=[service_fault_retry_policy],
        transform_step_retry_policies=[service_fault_retry_policy],
        repack_model_step_retry_policies=[service_fault_retry_policy],
    )
    steps = collection.request_dicts()
    assert len(steps) == 2

    for step in steps:
        step_type = step["Type"]

        if step_type == "Model":
            expected_model_step = {
                "Name": "EstimatorTransformerStepCreateModelStep",
                "Type": "Model",
                "DependsOn": ["TestStep"],
                "RetryPolicies": [service_fault_retry_policy.to_request()],
                "Arguments": {
                    "ExecutionRoleArn": "DummyRole",
                    "PrimaryContainer": {
                        "Environment": {},
                        "Image": "fakeimage",
                        "ModelDataUrl": "s3://my-bucket/model.tar.gz",
                    },
                },
            }
            assert step == expected_model_step
            continue

        if step_type != "Transform":
            raise Exception("A step exists in the collection of an invalid type.")

        assert step["Name"] == "EstimatorTransformerStepTransformStep"
        assert step["RetryPolicies"] == [service_fault_retry_policy.to_request()]
        transform_args = step["Arguments"]
        assert isinstance(transform_args["ModelName"], Properties)
        # Drop ModelName so the remaining arguments can be compared to a literal.
        transform_args.pop("ModelName")
        assert "DependsOn" not in step
        expected_transform_args = {
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": f"s3://{BUCKET}/transform_manifest",
                    }
                }
            },
            "TransformOutput": {"S3OutputPath": None},
            "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
        }
        assert transform_args == expected_transform_args
def test_estimator_transformer_with_model_repack_with_estimator(estimator):
    """Check ``EstimatorTransformer`` output when an entry point is supplied.

    Three steps are expected: a ``Training`` step named
    ``EstimatorTransformerStepRepackModel`` that carries ``DependsOn``, and
    ``Model`` and ``Transform`` steps that must not carry it. Every step must
    carry the retry policy. A step of any other type fails the test.
    """
    create_model_inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    service_fault_retry_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        max_attempts=10,
    )
    collection = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        model_inputs=create_model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=TransformInput(data=f"s3://{BUCKET}/transform_manifest"),
        depends_on=["TestStep"],
        model_step_retry_policies=[service_fault_retry_policy],
        transform_step_retry_policies=[service_fault_retry_policy],
        repack_model_step_retry_policies=[service_fault_retry_policy],
        entry_point=f"{DATA_DIR}/dummy_script.py",
    )
    steps = collection.request_dicts()
    assert len(steps) == 3

    for step in steps:
        step_type = step["Type"]

        if step_type == "Training":
            assert step["Name"] == "EstimatorTransformerStepRepackModel"
            assert step["DependsOn"] == ["TestStep"]
            assert step["RetryPolicies"] == [service_fault_retry_policy.to_request()]
            training_args = step["Arguments"]
            # sagemaker_submit_directory is dynamically generated; remove it first.
            training_args["HyperParameters"].pop("sagemaker_submit_directory")
            expected_training_args = {
                "AlgorithmSpecification": {
                    "TrainingInputMode": "File",
                    "TrainingImage": "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
                    + "sagemaker-scikit-learn:0.23-1-cpu-py3",
                },
                "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
                "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
                "ResourceConfig": {
                    "InstanceCount": 1,
                    "InstanceType": "ml.m5.large",
                    "VolumeSizeInGB": 30,
                },
                "RoleArn": "DummyRole",
                "InputDataConfig": [
                    {
                        "DataSource": {
                            "S3DataSource": {
                                "S3DataType": "S3Prefix",
                                "S3Uri": "s3://my-bucket/model.tar.gz",
                                "S3DataDistributionType": "FullyReplicated",
                            }
                        },
                        "ChannelName": "training",
                    }
                ],
                "HyperParameters": {
                    "inference_script": '"dummy_script.py"',
                    "model_archive": '"s3://my-bucket/model.tar.gz"',
                    "dependencies": "null",
                    "source_dir": "null",
                    "sagemaker_program": '"_repack_model.py"',
                    "sagemaker_container_log_level": "20",
                    "sagemaker_region": '"us-west-2"',
                },
                "VpcConfig": {"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]},
                "DebugHookConfig": {
                    "S3OutputPath": "s3://my-bucket/",
                    "CollectionConfigurations": [],
                },
            }
            assert training_args == expected_training_args
        elif step_type == "Model":
            assert step["Name"] == "EstimatorTransformerStepCreateModelStep"
            assert step["RetryPolicies"] == [service_fault_retry_policy.to_request()]
            model_args = step["Arguments"]
            container = model_args["PrimaryContainer"]
            assert isinstance(container["ModelDataUrl"], Properties)
            # Drop ModelDataUrl so the remaining arguments can be compared to a literal.
            model_args["PrimaryContainer"].pop("ModelDataUrl")
            assert "DependsOn" not in step
            expected_model_args = {
                "ExecutionRoleArn": "DummyRole",
                "PrimaryContainer": {
                    "Environment": {},
                    "Image": "fakeimage",
                },
            }
            assert model_args == expected_model_args
        elif step_type == "Transform":
            assert step["Name"] == "EstimatorTransformerStepTransformStep"
            assert step["RetryPolicies"] == [service_fault_retry_policy.to_request()]
            transform_args = step["Arguments"]
            assert isinstance(transform_args["ModelName"], Properties)
            # Drop ModelName so the remaining arguments can be compared to a literal.
            transform_args.pop("ModelName")
            assert "DependsOn" not in step
            expected_transform_args = {
                "TransformInput": {
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": f"s3://{BUCKET}/transform_manifest",
                        }
                    }
                },
                "TransformOutput": {"S3OutputPath": None},
                "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
            }
            assert transform_args == expected_transform_args
        else:
            raise Exception("A step exists in the collection of an invalid type.")