# Renamed from ``test_estimator_transformer``: a later test in this module
# reuses that name, which rebinds it at import time so pytest would never
# collect this variant.
def test_estimator_transformer_without_retry_policies(estimator):
    """Check EstimatorTransformer request dicts when no optional args are given.

    Expects exactly two steps:

    * a ``Model`` step whose whole request dict is pinned (so it carries no
      ``DependsOn`` or ``RetryPolicies`` keys), and
    * a ``Transform`` step whose ``ModelName`` argument is a ``Properties``
      object; it is popped before the remaining arguments are compared.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    model_inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
    estimator_transformer = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=model_data,
        model_inputs=model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=transform_inputs,
    )
    request_dicts = estimator_transformer.request_dicts()
    assert len(request_dicts) == 2
    # The per-type branches below cannot notice a missing step type on their
    # own (e.g. two Model steps would pass), so pin the exact set first.
    assert sorted(request_dict["Type"] for request_dict in request_dicts) == [
        "Model",
        "Transform",
    ]
    for request_dict in request_dicts:
        if request_dict["Type"] == "Model":
            assert request_dict == {
                "Name": "EstimatorTransformerStepCreateModelStep",
                "Type": "Model",
                "Arguments": {
                    "ExecutionRoleArn": "DummyRole",
                    "PrimaryContainer": {
                        "Environment": {},
                        "Image": "fakeimage",
                        "ModelDataUrl": "s3://my-bucket/model.tar.gz",
                    },
                },
            }
        elif request_dict["Type"] == "Transform":
            assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
            arguments = request_dict["Arguments"]
            # ModelName is a Properties object rather than a literal, so
            # check its type and drop it before the dict comparison.
            assert isinstance(arguments["ModelName"], Properties)
            arguments.pop("ModelName")
            assert arguments == {
                "TransformInput": {
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": f"s3://{BUCKET}/transform_manifest",
                        }
                    }
                },
                "TransformOutput": {"S3OutputPath": None},
                "TransformResources": {
                    "InstanceCount": 1,
                    "InstanceType": "ml.c4.4xlarge",
                },
            }
        else:
            raise Exception("A step exists in the collection of an invalid type.")
# Renamed from ``test_transform_step``: the minimal test that follows in this
# module reuses that name, which rebinds it at import time so pytest would
# silently skip this variant.
def test_transform_step_with_optional_fields(sagemaker_session):
    """Check a TransformStep request that sets the optional top-level fields.

    Covers ``depends_on`` (extended afterwards via ``add_depends_on``),
    ``display_name``, ``description`` and ``cache_config``. Also checks the
    ``TransformJobName`` property expression.
    """
    transformer = Transformer(
        model_name=MODEL_NAME,
        instance_count=1,
        instance_type="c4.4xlarge",
        sagemaker_session=sagemaker_session,
    )
    inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
    cache_config = CacheConfig(enable_caching=True, expire_after="PT1H")
    step = TransformStep(
        name="MyTransformStep",
        depends_on=["TestStep"],
        transformer=transformer,
        display_name="TransformStep",
        description="TestDescription",
        inputs=inputs,
        cache_config=cache_config,
    )
    # Dependencies added after construction are appended after the ones
    # passed to the constructor.
    step.add_depends_on(["SecondTestStep"])
    assert step.to_request() == {
        "Name": "MyTransformStep",
        "Type": "Transform",
        "Description": "TestDescription",
        "DisplayName": "TransformStep",
        "DependsOn": ["TestStep", "SecondTestStep"],
        "Arguments": {
            "ModelName": "gisele",
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://my-bucket/transform_manifest",
                    }
                }
            },
            "TransformOutput": {"S3OutputPath": None},
            "TransformResources": {
                "InstanceCount": 1,
                "InstanceType": "c4.4xlarge",
            },
        },
        "CacheConfig": {"Enabled": True, "ExpireAfter": "PT1H"},
    }
    assert step.properties.TransformJobName.expr == {
        "Get": "Steps.MyTransformStep.TransformJobName"
    }
def test_transform_step(sagemaker_session):
    """Check the request a TransformStep builds from only name, transformer and inputs.

    Also checks that the step's ``TransformJobName`` property resolves to a
    ``Get`` expression on the step name.
    """
    batch_transformer = Transformer(
        model_name=MODEL_NAME,
        instance_count=1,
        instance_type="c4.4xlarge",
        sagemaker_session=sagemaker_session,
    )
    manifest_input = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
    transform_step = TransformStep(
        name="MyTransformStep",
        transformer=batch_transformer,
        inputs=manifest_input,
    )

    actual_request = transform_step.to_request()
    expected_request = {
        "Name": "MyTransformStep",
        "Type": "Transform",
        "Arguments": {
            "ModelName": "gisele",
            "TransformInput": {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://my-bucket/transform_manifest",
                    }
                }
            },
            "TransformOutput": {"S3OutputPath": None},
            "TransformResources": {
                "InstanceCount": 1,
                "InstanceType": "c4.4xlarge",
            },
        },
    }
    assert actual_request == expected_request

    job_name_expr = transform_step.properties.TransformJobName.expr
    assert job_name_expr == {"Get": "Steps.MyTransformStep.TransformJobName"}
def test_estimator_transformer(estimator):
    """Check EstimatorTransformer request dicts with depends_on and retry policies.

    Expects exactly two steps, ``Model`` and ``Transform``:

    * ``DependsOn`` is present only on the model step.
    * Both steps carry the configured service-fault retry policy.
    * The transform step's ``ModelName`` argument is a ``Properties`` object.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    model_inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    service_fault_retry_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        max_attempts=10,
    )
    transform_inputs = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
    estimator_transformer = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=model_data,
        model_inputs=model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=transform_inputs,
        depends_on=["TestStep"],
        model_step_retry_policies=[service_fault_retry_policy],
        transform_step_retry_policies=[service_fault_retry_policy],
        repack_model_step_retry_policies=[service_fault_retry_policy],
    )
    request_dicts = estimator_transformer.request_dicts()
    assert len(request_dicts) == 2
    # The per-type branches below cannot notice a missing step type on their
    # own (e.g. two Model steps would pass), so pin the exact set first.
    assert sorted(request_dict["Type"] for request_dict in request_dicts) == [
        "Model",
        "Transform",
    ]
    for request_dict in request_dicts:
        if request_dict["Type"] == "Model":
            assert request_dict == {
                "Name": "EstimatorTransformerStepCreateModelStep",
                "Type": "Model",
                "DependsOn": ["TestStep"],
                "RetryPolicies": [service_fault_retry_policy.to_request()],
                "Arguments": {
                    "ExecutionRoleArn": "DummyRole",
                    "PrimaryContainer": {
                        "Environment": {},
                        "Image": "fakeimage",
                        "ModelDataUrl": "s3://my-bucket/model.tar.gz",
                    },
                },
            }
        elif request_dict["Type"] == "Transform":
            assert request_dict["Name"] == "EstimatorTransformerStepTransformStep"
            assert request_dict["RetryPolicies"] == [
                service_fault_retry_policy.to_request()
            ]
            arguments = request_dict["Arguments"]
            # ModelName is a Properties object rather than a literal, so
            # check its type and drop it before the dict comparison.
            assert isinstance(arguments["ModelName"], Properties)
            arguments.pop("ModelName")
            assert "DependsOn" not in request_dict
            assert arguments == {
                "TransformInput": {
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataType": "S3Prefix",
                            "S3Uri": f"s3://{BUCKET}/transform_manifest",
                        }
                    }
                },
                "TransformOutput": {"S3OutputPath": None},
                "TransformResources": {
                    "InstanceCount": 1,
                    "InstanceType": "ml.c4.4xlarge",
                },
            }
        else:
            # The message was previously split across two physical lines
            # inside a plain string literal, which is a syntax error.
            raise Exception("A step exists in the collection of an invalid type.")
def test_estimator_transformer_with_model_repack_with_estimator(estimator):
    """Check the three request dicts emitted when an entry point is supplied.

    With ``entry_point`` set, a ``Training`` (repack) step is expected
    alongside the ``Model`` and ``Transform`` steps.

    * Only the repack step carries ``DependsOn``.
    * Every step carries the service-fault retry policy.
    * Run-dependent values are removed before the dict comparisons: the
      ``sagemaker_submit_directory`` hyperparameter, the model step's
      ``ModelDataUrl`` and the transform step's ``ModelName``.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    model_inputs = CreateModelInput(
        instance_type="c4.4xlarge",
        accelerator_type="ml.eia1.medium",
    )
    retry_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT],
        max_attempts=10,
    )
    manifest_input = TransformInput(data=f"s3://{BUCKET}/transform_manifest")
    estimator_transformer = EstimatorTransformer(
        name="EstimatorTransformerStep",
        estimator=estimator,
        model_data=model_data,
        model_inputs=model_inputs,
        instance_count=1,
        instance_type="ml.c4.4xlarge",
        transform_inputs=manifest_input,
        depends_on=["TestStep"],
        model_step_retry_policies=[retry_policy],
        transform_step_retry_policies=[retry_policy],
        repack_model_step_retry_policies=[retry_policy],
        entry_point=f"{DATA_DIR}/dummy_script.py",
    )
    step_requests = estimator_transformer.request_dicts()
    assert len(step_requests) == 3

    expected_retry_policies = [retry_policy.to_request()]
    expected_repack_args = {
        "AlgorithmSpecification": {
            "TrainingInputMode": "File",
            "TrainingImage": (
                "246618743249.dkr.ecr.us-west-2.amazonaws.com/"
                "sagemaker-scikit-learn:0.23-1-cpu-py3"
            ),
        },
        "OutputDataConfig": {"S3OutputPath": "s3://my-bucket/"},
        "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
            "VolumeSizeInGB": 30,
        },
        "RoleArn": "DummyRole",
        "InputDataConfig": [
            {
                "DataSource": {
                    "S3DataSource": {
                        "S3DataType": "S3Prefix",
                        "S3Uri": "s3://my-bucket/model.tar.gz",
                        "S3DataDistributionType": "FullyReplicated",
                    }
                },
                "ChannelName": "training",
            }
        ],
        "HyperParameters": {
            "inference_script": '"dummy_script.py"',
            "model_archive": '"s3://my-bucket/model.tar.gz"',
            "dependencies": "null",
            "source_dir": "null",
            "sagemaker_program": '"_repack_model.py"',
            "sagemaker_container_log_level": "20",
            "sagemaker_region": '"us-west-2"',
        },
        "VpcConfig": {"Subnets": ["abc", "def"], "SecurityGroupIds": ["123", "456"]},
        "DebugHookConfig": {
            "S3OutputPath": "s3://my-bucket/",
            "CollectionConfigurations": [],
        },
    }
    expected_model_args = {
        "ExecutionRoleArn": "DummyRole",
        "PrimaryContainer": {
            "Environment": {},
            "Image": "fakeimage",
        },
    }
    expected_transform_args = {
        "TransformInput": {
            "DataSource": {
                "S3DataSource": {
                    "S3DataType": "S3Prefix",
                    "S3Uri": f"s3://{BUCKET}/transform_manifest",
                }
            }
        },
        "TransformOutput": {"S3OutputPath": None},
        "TransformResources": {"InstanceCount": 1, "InstanceType": "ml.c4.4xlarge"},
    }

    for step_request in step_requests:
        step_type = step_request["Type"]
        if step_type == "Training":
            assert step_request["Name"] == "EstimatorTransformerStepRepackModel"
            assert step_request["DependsOn"] == ["TestStep"]
            assert step_request["RetryPolicies"] == expected_retry_policies
            repack_args = step_request["Arguments"]
            # The submit directory is generated dynamically, so drop it
            # before comparing against the fixed expectation.
            repack_args["HyperParameters"].pop("sagemaker_submit_directory")
            assert repack_args == expected_repack_args
        elif step_type == "Model":
            assert step_request["Name"] == "EstimatorTransformerStepCreateModelStep"
            assert step_request["RetryPolicies"] == expected_retry_policies
            model_args = step_request["Arguments"]
            container = model_args["PrimaryContainer"]
            # ModelDataUrl is a Properties object, not a literal URL.
            assert isinstance(container["ModelDataUrl"], Properties)
            container.pop("ModelDataUrl")
            assert "DependsOn" not in step_request
            assert model_args == expected_model_args
        elif step_type == "Transform":
            assert step_request["Name"] == "EstimatorTransformerStepTransformStep"
            assert step_request["RetryPolicies"] == expected_retry_policies
            transform_args = step_request["Arguments"]
            # ModelName is a Properties object, not a literal model name.
            assert isinstance(transform_args["ModelName"], Properties)
            transform_args.pop("ModelName")
            assert "DependsOn" not in step_request
            assert transform_args == expected_transform_args
        else:
            raise Exception("A step exists in the collection of an invalid type.")