def test_register_model_tf(estimator_tf, model_metrics, drift_check_baselines):
    """Verify the request emitted when RegisterModel wraps ``estimator_tf``.

    No entry point is supplied, and the test expects a single
    ``RegisterModel`` request whose container references a TensorFlow
    inference image and whose arguments carry model metrics, drift-check
    baselines, approval status and description.
    """
    step = RegisterModel(
        name="RegisterModelStep",
        estimator=estimator_tf,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        drift_check_baselines=drift_check_baselines,
        approval_status="Approved",
        description="description",
    )
    actual_requests = step.request_dicts()

    # Build the expected payload bottom-up so each section is easy to locate.
    expected_inference_spec = {
        "Containers": [
            {
                "Image": "763104351884.dkr.ecr.us-west-2.amazonaws.com/tensorflow-inference:1.15.2-cpu",
                "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
            }
        ],
        "SupportedContentTypes": ["content_type"],
        "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
        "SupportedResponseMIMETypes": ["response_type"],
        "SupportedTransformInstanceTypes": ["transform_instance"],
    }
    expected_model_metrics = {
        "Bias": {},
        "Explainability": {},
        "ModelQuality": {
            "Statistics": {
                "ContentType": "text/csv",
                "S3Uri": f"s3://{BUCKET}/metrics.csv",
            },
        },
    }
    expected_drift_baselines = {
        "ModelQuality": {
            "Constraints": {
                "ContentType": "text/csv",
                "S3Uri": f"s3://{BUCKET}/constraints_metrics.csv",
            }
        }
    }
    expected_requests = [
        {
            "Name": "RegisterModelStep",
            "Type": "RegisterModel",
            "Description": "description",
            "Arguments": {
                "InferenceSpecification": expected_inference_spec,
                "ModelApprovalStatus": "Approved",
                "ModelMetrics": expected_model_metrics,
                "DriftCheckBaselines": expected_drift_baselines,
                "ModelPackageDescription": "description",
                "ModelPackageGroupName": "mpg",
            },
        },
    ]

    assert ordered(actual_requests) == ordered(expected_requests)
# Example #2
# 0
def test_register_model(estimator, model_metrics):
    """Verify the single request produced by a RegisterModel step with tags.

    The step is given ``depends_on`` and ``tags``; the test expects both to
    surface on the one ``RegisterModel`` request (``DependsOn`` at the step
    level, ``Tags`` inside ``Arguments``).
    """
    # NOTE(review): another ``test_register_model`` is defined later in this
    # file. A later module-level ``def`` with the same name rebinds it, so this
    # version may never be collected -- confirm and rename one of them.
    step = RegisterModel(
        name="RegisterModelStep",
        estimator=estimator,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        approval_status="Approved",
        description="description",
        depends_on=["TestStep"],
        tags=[{"Key": "myKey", "Value": "myValue"}],
    )
    actual_requests = step.request_dicts()

    expected_inference_spec = {
        "Containers": [
            {
                "Image": "fakeimage",
                "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
            }
        ],
        "SupportedContentTypes": ["content_type"],
        "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
        "SupportedResponseMIMETypes": ["response_type"],
        "SupportedTransformInstanceTypes": ["transform_instance"],
    }
    expected_model_metrics = {
        "ModelQuality": {
            "Statistics": {
                "ContentType": "text/csv",
                "S3Uri": f"s3://{BUCKET}/metrics.csv",
            },
        },
    }
    expected_requests = [
        {
            "Name": "RegisterModelStep",
            "Type": "RegisterModel",
            "DependsOn": ["TestStep"],
            "Arguments": {
                "InferenceSpecification": expected_inference_spec,
                "ModelApprovalStatus": "Approved",
                "ModelMetrics": expected_model_metrics,
                "ModelPackageDescription": "description",
                "ModelPackageGroupName": "mpg",
                "Tags": [{"Key": "myKey", "Value": "myValue"}],
            },
        },
    ]

    assert ordered(actual_requests) == ordered(expected_requests)
def test_register_model_with_model_repack(estimator, model_metrics):
    """Verify that supplying ``entry_point`` yields a repack step and a register step.

    Two request dicts are expected: a ``Training`` request named
    ``RegisterModelStepRepackModel`` that carries the ``depends_on`` value, and
    a ``RegisterModel`` request without ``DependsOn`` whose container
    ``ModelDataUrl`` is a ``Properties`` reference rather than a literal.

    The steps are looked up by type instead of being dispatched inside a loop,
    so the test fails if either step is missing (for example, two steps of the
    same type) rather than silently skipping that step's assertions.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    register_model = RegisterModel(
        name="RegisterModelStep",
        estimator=estimator,
        model_data=model_data,
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        approval_status="Approved",
        description="description",
        entry_point=f"{DATA_DIR}/dummy_script.py",
        depends_on=["TestStep"],
    )

    request_dicts = register_model.request_dicts()
    assert len(request_dicts) == 2

    # With exactly two requests, requiring both type keys guarantees one
    # request of each type and rejects any unexpected step type.
    steps_by_type = {request_dict["Type"]: request_dict for request_dict in request_dicts}
    assert set(steps_by_type) == {"Training", "RegisterModel"}, (
        "Expected exactly one Training step and one RegisterModel step, "
        f"got types: {[request_dict['Type'] for request_dict in request_dicts]}"
    )

    # --- Repack (Training) step ---
    repack_step = steps_by_type["Training"]
    assert repack_step["Name"] == "RegisterModelStepRepackModel"
    assert len(repack_step["DependsOn"]) == 1
    assert repack_step["DependsOn"][0] == "TestStep"
    repack_arguments = repack_step["Arguments"]
    # The job name is not fixed by this test, so it is read back from the
    # request and reused when building the expected hyperparameters.
    repacker_job_name = repack_arguments["HyperParameters"]["sagemaker_job_name"]
    assert ordered(repack_arguments) == ordered(
        {
            "AlgorithmSpecification": {
                "TrainingImage": MODEL_REPACKING_IMAGE_URI,
                "TrainingInputMode": "File",
            },
            "DebugHookConfig": {
                "CollectionConfigurations": [],
                "S3OutputPath": f"s3://{BUCKET}/",
            },
            "HyperParameters": {
                "inference_script": '"dummy_script.py"',
                "model_archive": '"model.tar.gz"',
                "sagemaker_submit_directory": '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
                    BUCKET, repacker_job_name.replace('"', "")
                ),
                "sagemaker_program": '"_repack_model.py"',
                "sagemaker_container_log_level": "20",
                "sagemaker_job_name": repacker_job_name,
                "sagemaker_region": f'"{REGION}"',
            },
            "InputDataConfig": [
                {
                    "ChannelName": "training",
                    "DataSource": {
                        "S3DataSource": {
                            "S3DataDistributionType": "FullyReplicated",
                            "S3DataType": "S3Prefix",
                            "S3Uri": f"s3://{BUCKET}",
                        }
                    },
                }
            ],
            "OutputDataConfig": {"S3OutputPath": f"s3://{BUCKET}/"},
            "ResourceConfig": {
                "InstanceCount": 1,
                "InstanceType": "ml.m5.large",
                "VolumeSizeInGB": 30,
            },
            "RoleArn": ROLE,
            "StoppingCondition": {"MaxRuntimeInSeconds": 86400},
        }
    )

    # --- RegisterModel step ---
    register_step = steps_by_type["RegisterModel"]
    assert register_step["Name"] == "RegisterModelStep"
    assert "DependsOn" not in register_step
    register_arguments = register_step["Arguments"]
    containers = register_arguments["InferenceSpecification"]["Containers"]
    assert len(containers) == 1
    assert containers[0]["Image"] == estimator.training_image_uri()
    assert isinstance(containers[0]["ModelDataUrl"], Properties)
    # The container holds a Properties object, so it is checked above and then
    # removed before comparing the rest of the arguments against plain literals.
    del register_arguments["InferenceSpecification"]["Containers"]
    assert ordered(register_arguments) == ordered(
        {
            "InferenceSpecification": {
                "SupportedContentTypes": ["content_type"],
                "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
                "SupportedResponseMIMETypes": ["response_type"],
                "SupportedTransformInstanceTypes": ["transform_instance"],
            },
            "ModelApprovalStatus": "Approved",
            "ModelMetrics": {
                "ModelQuality": {
                    "Statistics": {
                        "ContentType": "text/csv",
                        "S3Uri": f"s3://{BUCKET}/metrics.csv",
                    },
                },
            },
            "ModelPackageDescription": "description",
            "ModelPackageGroupName": "mpg",
        }
    )
def test_register_model_sip(estimator, model_metrics):
    """Verify registration of a ``PipelineModel`` built from two ``Model`` objects.

    The expected single ``RegisterModel`` request lists one container per
    model, in order, each with its image, model data URL and environment.
    """
    first_model = Model(image_uri="fakeimage1", model_data="Url1", env=[{"k1": "v1"}, {"k2": "v2"}])
    second_model = Model(image_uri="fakeimage2", model_data="Url2", env=[{"k3": "v3"}, {"k4": "v4"}])
    pipeline_model = PipelineModel([first_model, second_model], ROLE)

    step = RegisterModel(
        name="RegisterModelStep",
        estimator=estimator,
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        approval_status="Approved",
        description="description",
        model=pipeline_model,
        depends_on=["TestStep"],
    )
    actual_requests = step.request_dicts()

    expected_containers = [
        {
            "Image": "fakeimage1",
            "ModelDataUrl": "Url1",
            "Environment": [{"k1": "v1"}, {"k2": "v2"}],
        },
        {
            "Image": "fakeimage2",
            "ModelDataUrl": "Url2",
            "Environment": [{"k3": "v3"}, {"k4": "v4"}],
        },
    ]
    expected_arguments = {
        "InferenceSpecification": {
            "Containers": expected_containers,
            "SupportedContentTypes": ["content_type"],
            "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
            "SupportedResponseMIMETypes": ["response_type"],
            "SupportedTransformInstanceTypes": ["transform_instance"],
        },
        "ModelApprovalStatus": "Approved",
        "ModelMetrics": {
            "ModelQuality": {
                "Statistics": {
                    "ContentType": "text/csv",
                    "S3Uri": f"s3://{BUCKET}/metrics.csv",
                },
            },
        },
        "ModelPackageDescription": "description",
        "ModelPackageGroupName": "mpg",
    }
    expected_requests = [
        {
            "Name": "RegisterModelStep",
            "Type": "RegisterModel",
            "DependsOn": ["TestStep"],
            "Arguments": expected_arguments,
        },
    ]

    assert ordered(actual_requests) == ordered(expected_requests)
def test_register_model_with_model_repack_with_pipeline_model(
        pipeline_model, model_metrics, drift_check_baselines):
    """Verify the repack and register requests produced for a ``pipeline_model``.

    Two request dicts are expected: a ``Training`` request named
    ``modelNameRepackModel`` that carries the ``depends_on`` value and the
    tags, and a ``RegisterModel`` request without ``DependsOn`` whose container
    ``ModelDataUrl`` is a ``Properties`` reference rather than a literal.

    The steps are looked up by type instead of being dispatched inside a loop,
    so the test fails if either step is missing (for example, two steps of the
    same type) rather than silently skipping that step's assertions.
    """
    model_data = f"s3://{BUCKET}/model.tar.gz"
    service_fault_retry_policy = StepRetryPolicy(
        exception_types=[StepExceptionTypeEnum.SERVICE_FAULT], max_attempts=10)
    register_model = RegisterModel(
        name="RegisterModelStep",
        model=pipeline_model,
        model_data=model_data,
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        drift_check_baselines=drift_check_baselines,
        approval_status="Approved",
        description="description",
        depends_on=["TestStep"],
        repack_model_step_retry_policies=[service_fault_retry_policy],
        register_model_step_retry_policies=[service_fault_retry_policy],
        tags=[{
            "Key": "myKey",
            "Value": "myValue"
        }],
    )

    request_dicts = register_model.request_dicts()
    assert len(request_dicts) == 2

    # With exactly two requests, requiring both type keys guarantees one
    # request of each type and rejects any unexpected step type.
    steps_by_type = {
        request_dict["Type"]: request_dict
        for request_dict in request_dicts
    }
    assert set(steps_by_type) == {"Training", "RegisterModel"}, (
        "Expected exactly one Training step and one RegisterModel step, "
        f"got types: {[request_dict['Type'] for request_dict in request_dicts]}"
    )

    # --- Repack (Training) step ---
    repack_step = steps_by_type["Training"]
    assert repack_step["Name"] == "modelNameRepackModel"
    assert len(repack_step["DependsOn"]) == 1
    assert repack_step["DependsOn"][0] == "TestStep"
    repack_arguments = repack_step["Arguments"]
    # The job name is not fixed by this test, so it is read back from the
    # request and reused when building the expected hyperparameters.
    repacker_job_name = repack_arguments["HyperParameters"][
        "sagemaker_job_name"]
    assert ordered(repack_arguments) == ordered({
        "AlgorithmSpecification": {
            "TrainingImage": MODEL_REPACKING_IMAGE_URI,
            "TrainingInputMode": "File",
        },
        "DebugHookConfig": {
            "CollectionConfigurations": [],
            "S3OutputPath": f"s3://{BUCKET}/",
        },
        "HyperParameters": {
            "dependencies":
            "null",
            "inference_script":
            '"dummy_script.py"',
            "model_archive":
            '"model.tar.gz"',
            "sagemaker_submit_directory":
            '"s3://{}/{}/source/sourcedir.tar.gz"'.format(
                BUCKET, repacker_job_name.replace('"', "")),
            "sagemaker_program":
            '"_repack_model.py"',
            "sagemaker_container_log_level":
            "20",
            "sagemaker_job_name":
            repacker_job_name,
            "sagemaker_region":
            f'"{REGION}"',
            "source_dir":
            "null",
        },
        "InputDataConfig": [{
            "ChannelName": "training",
            "DataSource": {
                "S3DataSource": {
                    "S3DataDistributionType": "FullyReplicated",
                    "S3DataType": "S3Prefix",
                    "S3Uri": f"s3://{BUCKET}",
                }
            },
        }],
        "OutputDataConfig": {
            "S3OutputPath": f"s3://{BUCKET}/"
        },
        "ResourceConfig": {
            "InstanceCount": 1,
            "InstanceType": "ml.m5.large",
            "VolumeSizeInGB": 30,
        },
        "RoleArn":
        ROLE,
        "StoppingCondition": {
            "MaxRuntimeInSeconds": 86400
        },
        "Tags": [{
            "Key": "myKey",
            "Value": "myValue"
        }],
        "VpcConfig": [
            ("SecurityGroupIds", ["123", "456"]),
            ("Subnets", ["abc", "def"]),
        ],
    })

    # --- RegisterModel step ---
    register_step = steps_by_type["RegisterModel"]
    assert register_step["Name"] == "RegisterModelStep"
    assert "DependsOn" not in register_step
    register_arguments = register_step["Arguments"]
    containers = register_arguments["InferenceSpecification"]["Containers"]
    assert len(containers) == 1
    assert containers[0]["Image"] == pipeline_model.models[0].image_uri
    assert isinstance(containers[0]["ModelDataUrl"], Properties)
    # The container holds a Properties object, so it is checked above and then
    # removed before comparing the rest of the arguments against plain literals.
    del register_arguments["InferenceSpecification"]["Containers"]
    assert ordered(register_arguments) == ordered({
        "InferenceSpecification": {
            "SupportedContentTypes": ["content_type"],
            "SupportedRealtimeInferenceInstanceTypes":
            ["inference_instance"],
            "SupportedResponseMIMETypes": ["response_type"],
            "SupportedTransformInstanceTypes": ["transform_instance"],
        },
        "ModelApprovalStatus":
        "Approved",
        "ModelMetrics": {
            "Bias": {},
            "Explainability": {},
            "ModelQuality": {
                "Statistics": {
                    "ContentType": "text/csv",
                    "S3Uri": f"s3://{BUCKET}/metrics.csv",
                },
            },
        },
        "DriftCheckBaselines": {
            "ModelQuality": {
                "Constraints": {
                    "ContentType": "text/csv",
                    "S3Uri": f"s3://{BUCKET}/constraints_metrics.csv",
                }
            }
        },
        "ModelPackageDescription":
        "description",
        "ModelPackageGroupName":
        "mpg",
        "Tags": [{
            "Key": "myKey",
            "Value": "myValue"
        }],
    })
def test_register_model(estimator, model_metrics, drift_check_baselines):
    """Verify the request for a RegisterModel step with a custom ``image_uri``.

    The step also sets ``display_name``, ``depends_on``, ``tags`` and
    drift-check baselines; the test expects each of them on the single
    ``RegisterModel`` request, with the custom image used as the container
    image.
    """
    # NOTE(review): an earlier ``test_register_model`` exists in this file.
    # This later definition rebinds the name, so the earlier test is likely
    # shadowed -- confirm and rename one of them.
    custom_image_uri = "012345678901.dkr.ecr.us-west-2.amazonaws.com/my-custom-image-uri"
    step = RegisterModel(
        name="RegisterModelStep",
        estimator=estimator,
        model_data=f"s3://{BUCKET}/model.tar.gz",
        content_types=["content_type"],
        response_types=["response_type"],
        inference_instances=["inference_instance"],
        transform_instances=["transform_instance"],
        image_uri=custom_image_uri,
        model_package_group_name="mpg",
        model_metrics=model_metrics,
        drift_check_baselines=drift_check_baselines,
        approval_status="Approved",
        description="description",
        display_name="RegisterModelStep",
        depends_on=["TestStep"],
        tags=[{"Key": "myKey", "Value": "myValue"}],
    )
    actual_requests = step.request_dicts()

    expected_inference_spec = {
        "Containers": [
            {
                "Image": custom_image_uri,
                "ModelDataUrl": f"s3://{BUCKET}/model.tar.gz",
            }
        ],
        "SupportedContentTypes": ["content_type"],
        "SupportedRealtimeInferenceInstanceTypes": ["inference_instance"],
        "SupportedResponseMIMETypes": ["response_type"],
        "SupportedTransformInstanceTypes": ["transform_instance"],
    }
    expected_model_metrics = {
        "Bias": {},
        "Explainability": {},
        "ModelQuality": {
            "Statistics": {
                "ContentType": "text/csv",
                "S3Uri": f"s3://{BUCKET}/metrics.csv",
            },
        },
    }
    expected_drift_baselines = {
        "ModelQuality": {
            "Constraints": {
                "ContentType": "text/csv",
                "S3Uri": f"s3://{BUCKET}/constraints_metrics.csv",
            }
        }
    }
    expected_requests = [
        {
            "Name": "RegisterModelStep",
            "Type": "RegisterModel",
            "DependsOn": ["TestStep"],
            "DisplayName": "RegisterModelStep",
            "Description": "description",
            "Arguments": {
                "InferenceSpecification": expected_inference_spec,
                "ModelApprovalStatus": "Approved",
                "ModelMetrics": expected_model_metrics,
                "DriftCheckBaselines": expected_drift_baselines,
                "ModelPackageDescription": "description",
                "ModelPackageGroupName": "mpg",
                "Tags": [{"Key": "myKey", "Value": "myValue"}],
            },
        },
    ]

    assert ordered(actual_requests) == ordered(expected_requests)