Example #1
def _check_or_create_efs(sagemaker_session):
    efs_client = sagemaker_session.boto_session.client("efs")
    file_system_exists = False
    efs_id = ""
    try:
        create_response = efs_client.create_file_system(CreationToken=EFS_CREATION_TOKEN)
        efs_id = create_response["FileSystemId"]
    except ClientError as e:
        error_code = e.response["Error"]["Code"]
        if error_code == "FileSystemAlreadyExists":
            file_system_exists = True
            logging.warning(
                "File system with given creation token %s already exists", EFS_CREATION_TOKEN
            )
        else:
            raise

    if file_system_exists:
        desc = efs_client.describe_file_systems(CreationToken=EFS_CREATION_TOKEN)
        efs_id = desc["FileSystems"][0]["FileSystemId"]
        mount_target_id = efs_client.describe_mount_targets(FileSystemId=efs_id)["MountTargets"][0][
            "MountTargetId"
        ]
        return efs_id, mount_target_id

    for _ in retries(50, "Checking EFS creating status"):
        desc = efs_client.describe_file_systems(CreationToken=EFS_CREATION_TOKEN)
        status = desc["FileSystems"][0]["LifeCycleState"]
        if status == "available":
            break

    # Mirror the already-exists branch above so callers always receive an
    # (efs_id, mount_target_id) pair.
    mount_target_id = _create_efs_mount(sagemaker_session, efs_id)
    return efs_id, mount_target_id
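
Every example on this page drives its polling loop through a retries(...) helper that is not shown here. A minimal sketch of such a generator, assuming it yields once per attempt, sleeps between attempts, and raises once the retry budget is exhausted (the exception type and message format are assumptions):

import time


def retries(max_retry_count, exception_message_prefix, seconds_to_sleep=10):
    """Yield once per attempt; raise when every attempt has been used up."""
    for attempt in range(max_retry_count):
        yield attempt                 # the caller runs one check and breaks on success
        time.sleep(seconds_to_sleep)  # otherwise wait before the next attempt
    raise Exception(
        "{}: exceeded the maximum of {} retries".format(
            exception_message_prefix, max_retry_count
        )
    )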
Example #2
def tear_down(sagemaker_session, fs_resources=None):
    # Avoid a mutable default argument; fall back to an empty dict.
    fs_resources = fs_resources or {}
    try:
        if "file_system_fsx_id" in fs_resources:
            fsx_client = sagemaker_session.boto_session.client("fsx")
            fsx_client.delete_file_system(FileSystemId=fs_resources["file_system_fsx_id"])

        efs_client = sagemaker_session.boto_session.client("efs")
        if "mount_efs_target_id" in fs_resources:
            efs_client.delete_mount_target(MountTargetId=fs_resources["mount_efs_target_id"])

        if "file_system_efs_id" in fs_resources:
            for _ in retries(30, "Checking mount target deleting status"):
                desc = efs_client.describe_mount_targets(
                    FileSystemId=fs_resources["file_system_efs_id"]
                )
                if len(desc["MountTargets"]) > 0:
                    status = desc["MountTargets"][0]["LifeCycleState"]
                    if status == "deleted":
                        break
                else:
                    break

            efs_client.delete_file_system(FileSystemId=fs_resources["file_system_efs_id"])

        if "ec2_instance_id" in fs_resources:
            ec2_resource = sagemaker_session.boto_session.resource("ec2")
            _terminate_instance(ec2_resource, [fs_resources["ec2_instance_id"]])

        _delete_key_pair(sagemaker_session)

    except Exception:
        pass
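
A hedged sketch of how the dict-based tear_down above could be wired into a pytest fixture; the fixture name, scope, and the idea of sharing a module-level fs_resources dict are assumptions, not part of the original example:

import pytest


@pytest.fixture(scope="module")
def efs_fsx_resources(sagemaker_session):
    # Hypothetical wiring: the _create_* helpers further down this page populate
    # this dict with keys such as "file_system_efs_id" and "mount_efs_target_id".
    fs_resources = {}
    try:
        yield fs_resources
    finally:
        tear_down(sagemaker_session, fs_resources)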
Example #3
def tear_down(sagemaker_session, fs_resources):
    fsx_client = sagemaker_session.boto_session.client("fsx")
    file_system_fsx_id = fs_resources.file_system_fsx_id
    fsx_client.delete_file_system(FileSystemId=file_system_fsx_id)

    efs_client = sagemaker_session.boto_session.client("efs")
    mount_efs_target_id = fs_resources.mount_efs_target_id
    efs_client.delete_mount_target(MountTargetId=mount_efs_target_id)

    file_system_efs_id = fs_resources.file_system_efs_id
    for _ in retries(30, "Checking mount target deleting status"):
        desc = efs_client.describe_mount_targets(FileSystemId=file_system_efs_id)
        if len(desc["MountTargets"]) > 0:
            status = desc["MountTargets"][0]["LifeCycleState"]
            if status == "deleted":
                break
        else:
            break

    efs_client.delete_file_system(FileSystemId=file_system_efs_id)

    ec2_resource = sagemaker_session.boto_session.resource("ec2")
    instance_id = fs_resources.ec2_instance_id
    _terminate_instance(ec2_resource, [instance_id])

    _delete_key_pair(sagemaker_session)
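
Both tear_down variants call _terminate_instance and _delete_key_pair, which are not shown on this page. Minimal sketches, assuming the key pair name is a module-level constant (KEY_PAIR_NAME is a placeholder, not taken from the original code):

KEY_PAIR_NAME = "fs-integ-test-key-pair"  # assumed constant


def _terminate_instance(ec2_resource, instance_ids):
    # Terminate every EC2 instance in the given id list.
    for instance in ec2_resource.instances.filter(InstanceIds=instance_ids):
        instance.terminate()


def _delete_key_pair(sagemaker_session):
    # Remove the throwaway key pair created for the file-system test instance.
    ec2_client = sagemaker_session.boto_session.client("ec2")
    ec2_client.delete_key_pair(KeyName=KEY_PAIR_NAME)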
Example #4
def _wait_for_completion(monitor):
    """Waits for the schedule to have an execution in a terminal status.

    Args:
        monitor (sagemaker.model_monitor.ModelMonitor): The monitor to watch.

    """
    for _ in retries(
            max_retry_count=200,
            exception_message_prefix=
            "Waiting for the latest execution to be in a terminal status.",
            seconds_to_sleep=60,
    ):
        schedule_desc = monitor.describe_schedule()
        execution_summary = schedule_desc.get("LastMonitoringExecutionSummary")
        last_execution_status = None

        # Once there is an execution, get its status
        if execution_summary is not None:
            last_execution_status = execution_summary[
                "MonitoringExecutionStatus"]
            # Stop the schedule as soon as it's kicked off the execution that we need from it.
            if schedule_desc["MonitoringScheduleStatus"] not in [
                    "Pending", "Stopped"
            ]:
                monitor.stop_monitoring_schedule()
        # End this loop once the execution has reached a terminal state.
        if last_execution_status in [
                "Completed", "CompletedWithViolations", "Failed", "Stopped"
        ]:
            break
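
An illustrative call site for _wait_for_completion (the monitor setup is elided and the wrapper name is an assumption): block until the latest execution reaches a terminal status, then inspect it.

def _example_wait_and_check(monitor):
    # Illustrative only: `monitor` is a configured sagemaker.model_monitor monitor
    # whose schedule has already been created.
    _wait_for_completion(monitor)
    summary = monitor.describe_schedule()["LastMonitoringExecutionSummary"]
    assert summary["MonitoringExecutionStatus"] in ("Completed", "CompletedWithViolations")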
Example #5
def test_inference_pipeline_model_deploy_with_update_endpoint(
    sagemaker_session, cpu_instance_type, alternative_cpu_instance_type
):
    sparkml_data_path = os.path.join(DATA_DIR, "sparkml_model")
    xgboost_data_path = os.path.join(DATA_DIR, "xgboost_model")
    endpoint_name = "test-inference-pipeline-deploy-{}".format(sagemaker_timestamp())
    sparkml_model_data = sagemaker_session.upload_data(
        path=os.path.join(sparkml_data_path, "mleap_model.tar.gz"),
        key_prefix="integ-test-data/sparkml/model",
    )
    xgb_model_data = sagemaker_session.upload_data(
        path=os.path.join(xgboost_data_path, "xgb_model.tar.gz"),
        key_prefix="integ-test-data/xgboost/model",
    )

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        sparkml_model = SparkMLModel(
            model_data=sparkml_model_data,
            env={"SAGEMAKER_SPARKML_SCHEMA": SCHEMA},
            sagemaker_session=sagemaker_session,
        )
        xgb_image = get_image_uri(sagemaker_session.boto_region_name, "xgboost")
        xgb_model = Model(
            model_data=xgb_model_data, image=xgb_image, sagemaker_session=sagemaker_session
        )
        model = PipelineModel(
            models=[sparkml_model, xgb_model],
            role="SageMakerRole",
            sagemaker_session=sagemaker_session,
        )
        model.deploy(1, alternative_cpu_instance_type, endpoint_name=endpoint_name)
        old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name
        )
        old_config_name = old_endpoint["EndpointConfigName"]

        model.deploy(1, cpu_instance_type, update_endpoint=True, endpoint_name=endpoint_name)

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
        for _ in retries(40, "Waiting for 'InService' endpoint status", seconds_to_sleep=30):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=endpoint_name
            )
            if new_endpoint["EndpointStatus"] == "InService":
                break

        new_config_name = new_endpoint["EndpointConfigName"]
        new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=new_config_name
        )

        assert old_config_name != new_config_name
        assert new_config["ProductionVariants"][0]["InstanceType"] == cpu_instance_type
        assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1

    model.delete_model()
    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_model(ModelName=model.name)
    # The assertion must run outside the `with` block; statements after the
    # raising call inside pytest.raises never execute.
    assert "Could not find model" in str(exception.value)
Example #6
def _assert_tags_match(sagemaker_client, resource_arn, tags, retry_count=15):
    # endpoint and training tags might take minutes to propagate.
    for _ in retries(retry_count, "Getting endpoint tags", seconds_to_sleep=30):
        actual_tags = sagemaker_client.list_tags(ResourceArn=resource_arn)["Tags"]
        if actual_tags:
            break

    assert actual_tags == tags
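
An illustrative call to _assert_tags_match (the ARN and tag values are placeholders): verify that tags applied at endpoint creation eventually become visible through list_tags.

def _example_check_endpoint_tags(sagemaker_session, endpoint_arn):
    expected_tags = [{"Key": "team", "Value": "integ-tests"}]  # placeholder values
    _assert_tags_match(sagemaker_session.sagemaker_client, endpoint_arn, expected_tags)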
Example #7
def test_enabling_data_capture_on_endpoint_shows_correct_data_capture_status(
    sagemaker_session, tf_full_version
):
    endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR, "tensorflow-serving-test-model.tar.gz"),
        key_prefix="tensorflow-serving/models",
    )
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model = Model(
            model_data=model_data,
            role=ROLE,
            framework_version=tf_full_version,
            sagemaker_session=sagemaker_session,
        )
        predictor = model.deploy(
            initial_instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE,
            endpoint_name=endpoint_name,
        )

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc.get("DataCaptureConfig") is None

        predictor.enable_data_capture()

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
        for _ in retries(
            max_retry_count=40,
            exception_message_prefix="Waiting for 'InService' endpoint status",
            seconds_to_sleep=30,
        ):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=predictor.endpoint
            )
            if new_endpoint["EndpointStatus"] == "InService":
                break

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint
        )

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"]
        )

        assert endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]
Example #8
def _check_or_create_iam_profile_and_attach_role(sagemaker_session):
    if _instance_profile_exists(sagemaker_session):
        return
    iam_client = sagemaker_session.boto_session.client("iam")
    iam_client.create_instance_profile(InstanceProfileName=ROLE_NAME)
    iam_client.add_role_to_instance_profile(InstanceProfileName=ROLE_NAME, RoleName=ROLE_NAME)

    for _ in retries(30, "Checking EC2 instance profile creating status"):
        profile_info = iam_client.get_instance_profile(InstanceProfileName=ROLE_NAME)
        if profile_info["InstanceProfile"]["Roles"][0]["RoleName"] == ROLE_NAME:
            break
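
The helper above delegates to _instance_profile_exists, which is not included on this page. A minimal sketch, assuming a missing profile surfaces as a NoSuchEntity client error:

from botocore.exceptions import ClientError


def _instance_profile_exists(sagemaker_session):
    iam_client = sagemaker_session.boto_session.client("iam")
    try:
        iam_client.get_instance_profile(InstanceProfileName=ROLE_NAME)
        return True
    except ClientError as e:
        # Treat "profile not found" as False; re-raise anything unexpected.
        if e.response["Error"]["Code"] == "NoSuchEntity":
            return False
        raise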
Example #9
def test_deploy_model_with_update_endpoint(
    mxnet_training_job,
    sagemaker_session,
    mxnet_full_version,
    cpu_instance_type,
    alternative_cpu_instance_type,
):
    endpoint_name = "test-mxnet-deploy-model-{}".format(sagemaker_timestamp())

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        desc = sagemaker_session.sagemaker_client.describe_training_job(
            TrainingJobName=mxnet_training_job)
        model_data = desc["ModelArtifacts"]["S3ModelArtifacts"]
        script_path = os.path.join(DATA_DIR, "mxnet_mnist", "mnist.py")
        model = MXNetModel(
            model_data,
            "SageMakerRole",
            entry_point=script_path,
            py_version=PYTHON_VERSION,
            sagemaker_session=sagemaker_session,
            framework_version=mxnet_full_version,
        )
        model.deploy(1,
                     alternative_cpu_instance_type,
                     endpoint_name=endpoint_name)
        old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name)
        old_config_name = old_endpoint["EndpointConfigName"]

        model.deploy(1,
                     cpu_instance_type,
                     update_endpoint=True,
                     endpoint_name=endpoint_name)

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 40 retries * 30s sleeps = 20min timeout
        for _ in retries(40,
                         "Waiting for 'InService' endpoint status",
                         seconds_to_sleep=30):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=endpoint_name)
            if new_endpoint["EndpointStatus"] == "InService":
                break

        new_config_name = new_endpoint["EndpointConfigName"]
        new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=new_config_name)

        assert old_config_name != new_config_name
        assert new_config["ProductionVariants"][0][
            "InstanceType"] == cpu_instance_type
        assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1
Example #10
def _create_efs(sagemaker_session):
    efs_client = sagemaker_session.boto_session.client("efs")
    create_response = efs_client.create_file_system(CreationToken=EFS_CREATION_TOKEN)
    efs_id = create_response["FileSystemId"]
    fs_resources["file_system_efs_id"] = efs_id
    for _ in retries(50, "Checking EFS creating status"):
        desc = efs_client.describe_file_systems(CreationToken=EFS_CREATION_TOKEN)
        status = desc["FileSystems"][0]["LifeCycleState"]
        if status == "available":
            break
    mount_target_id = _create_efs_mount(sagemaker_session, efs_id)

    return efs_id, mount_target_id
Example #11
def container_image(sagemaker_session):
    """Create a Multi-Model image since pre-built ones are not available yet."""
    algorithm_name = unique_name_from_base("sagemaker-multimodel-integ-test")
    ecr_image = _ecr_image_uri(sagemaker_session, algorithm_name)

    ecr_client = sagemaker_session.boto_session.client("ecr")
    username, password = _ecr_login(ecr_client)

    docker_client = docker.from_env()

    # Base image pull
    base_image = "142577830533.dkr.ecr.us-east-2.amazonaws.com/ubuntu:16.04"
    docker_client.images.pull(base_image,
                              auth_config={
                                  "username": username,
                                  "password": password
                              })

    # Build and tag docker image locally
    image, build_log = docker_client.images.build(
        path=os.path.join(DATA_DIR, "multimodel", "container"),
        tag=algorithm_name,
        rm=True,
    )
    image.tag(ecr_image, tag="latest")

    # Create AWS ECR and push the local docker image to it
    _create_repository(ecr_client, algorithm_name)

    # Retry docker image push
    for _ in retries(3, "Upload docker image to ECR repo",
                     seconds_to_sleep=10):
        try:
            docker_client.images.push(ecr_image,
                                      auth_config={
                                          "username": username,
                                          "password": password
                                      })
            break
        except requests.exceptions.ConnectionError:
            # This can happen when we try to create multiple repositories in parallel, so we retry
            pass

    yield ecr_image

    # Delete repository after the multi model integration tests complete
    _delete_repository(ecr_client, algorithm_name)
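
Both container_image fixtures call an _ecr_login helper that is not shown here. A hedged sketch, relying on ECR authorization tokens being base64-encoded "username:password" pairs:

import base64


def _ecr_login(ecr_client):
    # Decode the temporary ECR credentials used for the docker pull/push above.
    auth = ecr_client.get_authorization_token()["authorizationData"][0]
    token = base64.b64decode(auth["authorizationToken"]).decode("utf-8")
    username, password = token.split(":", 1)
    return username, password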
Example #12
def container_image(sagemaker_session):
    """ Create a Multi-Model container image for use with integration testcases
    since 1P containers supporting multiple models are not available yet"""
    region = sagemaker_session.boto_region_name
    ecr_client = sagemaker_session.boto_session.client("ecr",
                                                       region_name=region)
    sts_client = sagemaker_session.boto_session.client(
        "sts",
        region_name=region,
        endpoint_url=utils.sts_regional_endpoint(region))
    account_id = sts_client.get_caller_identity()["Account"]
    algorithm_name = "sagemaker-multimodel-integ-test-{}".format(
        sagemaker_timestamp())
    ecr_image_uri_prefix = get_ecr_image_uri_prefix(account=account_id,
                                                    region=region)
    ecr_image = "{prefix}/{algorithm_name}:latest".format(
        prefix=ecr_image_uri_prefix, algorithm_name=algorithm_name)

    # Build and tag docker image locally
    docker_client = docker.from_env()
    image, build_log = docker_client.images.build(path=os.path.join(
        DATA_DIR, "multimodel", "container"),
                                                  tag=algorithm_name,
                                                  rm=True)
    image.tag(ecr_image, tag="latest")

    # Create AWS ECR and push the local docker image to it
    _create_repository(ecr_client, algorithm_name)
    username, password = _ecr_login(ecr_client)
    # Retry docker image push
    for _ in retries(3, "Upload docker image to ECR repo",
                     seconds_to_sleep=10):
        try:
            docker_client.images.push(ecr_image,
                                      auth_config={
                                          "username": username,
                                          "password": password
                                      })
            break
        except requests.exceptions.ConnectionError:
            # This can happen when we try to create multiple repositories in parallel, so we retry
            pass

    yield ecr_image

    # Delete repository after the multi model integration tests complete
    _delete_repository(ecr_client, algorithm_name)
Example #13
def _create_efs_mount(sagemaker_session, file_system_id):
    subnet_ids, security_group_ids = check_or_create_vpc_resources_efs_fsx(
        sagemaker_session, VPC_NAME
    )
    efs_client = sagemaker_session.boto_session.client("efs")
    mount_response = efs_client.create_mount_target(
        FileSystemId=file_system_id, SubnetId=subnet_ids[0], SecurityGroups=security_group_ids
    )
    mount_target_id = mount_response["MountTargetId"]

    for _ in retries(50, "Checking EFS mounting target status"):
        desc = efs_client.describe_mount_targets(MountTargetId=mount_target_id)
        status = desc["MountTargets"][0]["LifeCycleState"]
        if status == "available":
            break

    return mount_target_id
Example #14
def _create_fsx(sagemaker_session):
    fsx_client = sagemaker_session.boto_session.client("fsx")
    subnet_ids, security_group_ids = check_or_create_vpc_resources_efs_fsx(
        sagemaker_session, VPC_NAME)
    create_response = fsx_client.create_file_system(
        FileSystemType="LUSTRE",
        StorageCapacity=STORAGE_CAPACITY_IN_BYTES,
        SubnetIds=[subnet_ids[0]],
        SecurityGroupIds=security_group_ids,
    )
    fsx_id = create_response["FileSystem"]["FileSystemId"]
    fs_resources["file_system_fsx_id"] = fsx_id

    for _ in retries(50, "Checking FSX creating status"):
        desc = fsx_client.describe_file_systems(FileSystemIds=[fsx_id])
        status = desc["FileSystems"][0]["Lifecycle"]
        if status == "AVAILABLE":
            break

    return fsx_id
Example #15
def _delete_schedules_associated_with_endpoint(sagemaker_session,
                                               endpoint_name):
    """Deletes schedules associated with a given endpoint. Per latest validation, ensures the
    schedule is stopped and no executions are running, before deleting (otherwise latest
    server-side validations will prevent deletes).

    Args:
        sagemaker_session (sagemaker.session.Session): A SageMaker Session
            object, used for SageMaker interactions (default: None). If not
            specified, one is created using the default AWS configuration
            chain.
        endpoint_name (str): The name of the endpoint to delete schedules from.

    """
    predictor = RealTimePredictor(endpoint=endpoint_name,
                                  sagemaker_session=sagemaker_session)
    monitors = predictor.list_monitors()
    for monitor in monitors:
        try:
            monitor._wait_for_schedule_changes_to_apply()
            # Stop the schedules to prevent new executions from triggering.
            monitor.stop_monitoring_schedule()
            executions = monitor.list_executions()
            for execution in executions:
                execution.stop()
            # Wait for all executions to completely stop.
            # Schedules can't be deleted with running executions.
            for execution in executions:
                for _ in retries(60,
                                 "Waiting for executions to stop",
                                 seconds_to_sleep=5):
                    status = execution.describe()["ProcessingJobStatus"]
                    if status == "Stopped":
                        break
            # Delete schedules.
            monitor.delete_monitoring_schedule()
        except Exception as e:
            # Pass values as logging arguments so the message formats correctly.
            LOGGER.warning("Failed to delete monitor %s: %s",
                           monitor.monitoring_schedule_name, e)
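
An illustrative cleanup sequence (the wrapper and endpoint_name are placeholders): drop leftover monitoring schedules first, then delete the endpoint itself, since an endpoint with attached schedules cannot be deleted.

def _example_cleanup_endpoint(sagemaker_session, endpoint_name):
    _delete_schedules_associated_with_endpoint(sagemaker_session, endpoint_name)
    sagemaker_session.delete_endpoint(endpoint_name)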
Example #16
def _create_ec2_instance(
    sagemaker_session,
    image_id,
    instance_type,
    key_name,
    min_count,
    max_count,
    security_group_ids,
    subnet_id,
):
    ec2_resource = sagemaker_session.boto_session.resource("ec2")
    ec2_instances = ec2_resource.create_instances(
        ImageId=image_id,
        InstanceType=instance_type,
        KeyName=key_name,
        MinCount=min_count,
        MaxCount=max_count,
        IamInstanceProfile={"Name": ROLE_NAME},
        DryRun=False,
        NetworkInterfaces=[
            {
                "SubnetId": subnet_id,
                "DeviceIndex": 0,
                "AssociatePublicIpAddress": True,
                "Groups": security_group_ids,
            }
        ],
    )

    ec2_instances[0].wait_until_running()
    ec2_instances[0].reload()
    fs_resources["ec2_instance_id"] = ec2_instances[0].id
    ec2_client = sagemaker_session.boto_session.client("ec2")
    for _ in retries(30, "Checking EC2 creation status"):
        statuses = ec2_client.describe_instance_status(InstanceIds=[ec2_instances[0].id])
        status = statuses["InstanceStatuses"][0]
        if status["InstanceStatus"]["Status"] == "ok" and status["SystemStatus"]["Status"] == "ok":
            break
    return ec2_instances[0]
Example #17
def _wait_and_assert_that_no_rule_jobs_errored(training_job):
    # Wait for all rule jobs to complete.
    # Rule job completion takes ~5min after the training job ends.
    # 120 retries * 10s sleeps = 20min timeout
    for _ in retries(
        max_retry_count=120,
        exception_message_prefix="Waiting for all jobs to be in success status or any to be in error",
        seconds_to_sleep=10,
    ):
        job_description = training_job.describe()
        debug_rule_evaluation_statuses = job_description.get("DebugRuleEvaluationStatuses")
        if not debug_rule_evaluation_statuses:
            break
        incomplete_rule_job_found = False
        for debug_rule_evaluation_status in debug_rule_evaluation_statuses:
            assert debug_rule_evaluation_status["RuleEvaluationStatus"] != "Error"
            if (
                debug_rule_evaluation_status["RuleEvaluationStatus"]
                not in _NON_ERROR_TERMINAL_RULE_JOB_STATUSES
            ):
                incomplete_rule_job_found = True
        if not incomplete_rule_job_found:
            break
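
The loop above compares against _NON_ERROR_TERMINAL_RULE_JOB_STATUSES, which is defined elsewhere. A plausible definition, assuming the standard RuleEvaluationStatus values and treating these three as terminal and non-error:

# Assumed definition; not taken from the original module.
_NON_ERROR_TERMINAL_RULE_JOB_STATUSES = ("NoIssuesFound", "IssuesFound", "Stopped")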
Example #18
def test_model_registration_with_tuning_model(
    sagemaker_session,
    role,
    cpu_instance_type,
    pipeline_name,
    region_name,
):
    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
    entry_point = os.path.join(base_dir, "mnist.py")
    input_path = sagemaker_session.upload_data(
        path=os.path.join(base_dir, "training"),
        key_prefix="integ-test-data/pytorch_mnist/training",
    )
    inputs = TrainingInput(s3_data=input_path)

    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")

    pytorch_estimator = PyTorch(
        entry_point=entry_point,
        role=role,
        framework_version="1.5.0",
        py_version="py3",
        instance_count=instance_count,
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        enable_sagemaker_metrics=True,
        max_retry_attempts=3,
    )

    min_batch_size = ParameterString(name="MinBatchSize", default_value="64")
    max_batch_size = ParameterString(name="MaxBatchSize", default_value="128")
    hyperparameter_ranges = {
        "batch-size": IntegerParameter(min_batch_size, max_batch_size),
    }

    tuner = HyperparameterTuner(
        estimator=pytorch_estimator,
        objective_metric_name="test:acc",
        objective_type="Maximize",
        hyperparameter_ranges=hyperparameter_ranges,
        metric_definitions=[{"Name": "test:acc", "Regex": "Overall test accuracy: (.*?);"}],
        max_jobs=2,
        max_parallel_jobs=2,
    )

    step_tune = TuningStep(
        name="my-tuning-step",
        tuner=tuner,
        inputs=inputs,
    )

    step_register_best = RegisterModel(
        name="my-model-regis",
        estimator=pytorch_estimator,
        model_data=step_tune.get_top_model_s3_uri(
            top_k=0,
            s3_bucket=sagemaker_session.default_bucket(),
        ),
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.t2.medium", "ml.m5.large"],
        transform_instances=["ml.m5.large"],
        entry_point=entry_point,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count, instance_type, min_batch_size, max_batch_size],
        steps=[step_tune, step_register_best],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(parameters={})
            assert re.match(
                rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
                execution.arn,
            )
            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 3
            for step in execution_steps:
                assert step["StepStatus"] == "Succeeded"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
def test_model_registration_with_drift_check_baselines(
    sagemaker_session,
    role,
    pipeline_name,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType",
                                    default_value="ml.m5.xlarge")

    # upload model data to s3
    model_local_path = os.path.join(DATA_DIR, "mxnet_mnist/model.tar.gz")
    model_base_uri = "s3://{}/{}/input/model/{}".format(
        sagemaker_session.default_bucket(),
        "register_model_test_with_drift_baseline",
        utils.unique_name_from_base("model"),
    )
    model_uri = S3Uploader.upload(model_local_path,
                                  model_base_uri,
                                  sagemaker_session=sagemaker_session)
    model_uri_param = ParameterString(name="model_uri",
                                      default_value=model_uri)

    # upload metrics to s3
    metrics_data = (
        '{"regression_metrics": {"mse": {"value": 4.925353410353891, '
        '"standard_deviation": 2.219186917819692}}}')
    metrics_base_uri = "s3://{}/{}/input/metrics/{}".format(
        sagemaker_session.default_bucket(),
        "register_model_test_with_drift_baseline",
        utils.unique_name_from_base("metrics"),
    )
    metrics_uri = S3Uploader.upload_string_as_file_body(
        body=metrics_data,
        desired_s3_uri=metrics_base_uri,
        sagemaker_session=sagemaker_session,
    )
    metrics_uri_param = ParameterString(name="metrics_uri",
                                        default_value=metrics_uri)

    model_metrics = ModelMetrics(
        bias=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        explainability=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        bias_pre_training=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        bias_post_training=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
    )
    drift_check_baselines = DriftCheckBaselines(
        model_statistics=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        model_constraints=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        model_data_statistics=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        model_data_constraints=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        bias_config_file=FileSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        bias_pre_training_constraints=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        bias_post_training_constraints=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        explainability_constraints=MetricsSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
        explainability_config_file=FileSource(
            s3_uri=metrics_uri_param,
            content_type="application/json",
        ),
    )
    customer_metadata_properties = {"key1": "value1"}
    estimator = XGBoost(
        entry_point="training.py",
        source_dir=os.path.join(DATA_DIR, "sip"),
        instance_type=instance_type,
        instance_count=instance_count,
        framework_version="0.90-2",
        sagemaker_session=sagemaker_session,
        py_version="py3",
        role=role,
    )
    step_register = RegisterModel(
        name="MyRegisterModelStep",
        estimator=estimator,
        model_data=model_uri_param,
        content_types=["application/json"],
        response_types=["application/json"],
        inference_instances=["ml.t2.medium", "ml.m5.xlarge"],
        transform_instances=["ml.m5.xlarge"],
        model_package_group_name="testModelPackageGroup",
        model_metrics=model_metrics,
        drift_check_baselines=drift_check_baselines,
        customer_metadata_properties=customer_metadata_properties,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[
            model_uri_param,
            metrics_uri_param,
            instance_type,
            instance_count,
        ],
        steps=[step_register],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        for _ in retries(
                max_retry_count=5,
                exception_message_prefix=
                "Waiting for a successful execution of pipeline",
                seconds_to_sleep=10,
        ):
            execution = pipeline.start(parameters={
                "model_uri": model_uri,
                "metrics_uri": metrics_uri
            })
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if failure_reason != "":
                logging.error(
                    f"Pipeline execution failed with error: {failure_reason}."
                    " Retrying..")
                continue
            assert execution_steps[0]["StepStatus"] == "Succeeded"
            assert execution_steps[0]["StepName"] == "MyRegisterModelStep"

            response = sagemaker_session.sagemaker_client.describe_model_package(
                ModelPackageName=execution_steps[0]["Metadata"]
                ["RegisterModel"]["Arn"])

            assert (response["ModelMetrics"]["Explainability"]["Report"]
                    ["ContentType"] == "application/json")
            assert (response["DriftCheckBaselines"]["Bias"][
                "PreTrainingConstraints"]["ContentType"] == "application/json")
            assert (response["DriftCheckBaselines"]["Explainability"]
                    ["Constraints"]["ContentType"] == "application/json")
            assert (response["DriftCheckBaselines"]["ModelQuality"]
                    ["Statistics"]["ContentType"] == "application/json")
            assert (response["DriftCheckBaselines"]["ModelDataQuality"]
                    ["Statistics"]["ContentType"] == "application/json")
            assert response[
                "CustomerMetadataProperties"] == customer_metadata_properties
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
def test_multi_data_model_deploy_pretrained_models_update_endpoint(
        container_image, sagemaker_session, cpu_instance_type,
        alternative_cpu_instance_type):
    timestamp = sagemaker_timestamp()
    endpoint_name = "test-multimodel-endpoint-{}".format(timestamp)
    model_name = "test-multimodel-{}".format(timestamp)

    # Define pretrained model local path
    pretrained_model_data_local_path = os.path.join(DATA_DIR, "sparkml_model",
                                                    "mleap_model.tar.gz")

    with timeout_and_delete_endpoint_by_name(endpoint_name, sagemaker_session):
        model_data_prefix = os.path.join("s3://",
                                         sagemaker_session.default_bucket(),
                                         "multimodel-{}/".format(timestamp))
        multi_data_model = MultiDataModel(
            name=model_name,
            model_data_prefix=model_data_prefix,
            image=container_image,
            role=ROLE,
            sagemaker_session=sagemaker_session,
        )

        # Add model before deploy
        multi_data_model.add_model(pretrained_model_data_local_path,
                                   PRETRAINED_MODEL_PATH_1)
        # Deploy model to an endpoint
        multi_data_model.deploy(1,
                                cpu_instance_type,
                                endpoint_name=endpoint_name)
        # Add model after deploy
        multi_data_model.add_model(pretrained_model_data_local_path,
                                   PRETRAINED_MODEL_PATH_2)

        # List model assertions
        endpoint_models = []
        for model_path in multi_data_model.list_models():
            endpoint_models.append(model_path)
        assert PRETRAINED_MODEL_PATH_1 in endpoint_models
        assert PRETRAINED_MODEL_PATH_2 in endpoint_models

        predictor = RealTimePredictor(
            endpoint=endpoint_name,
            sagemaker_session=sagemaker_session,
            serializer=npy_serializer,
            deserializer=string_deserializer,
        )

        data = numpy.zeros(shape=(1, 1, 28, 28))
        result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_1)
        assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_1)

        result = predictor.predict(data, target_model=PRETRAINED_MODEL_PATH_2)
        assert result == "Invoked model: {}".format(PRETRAINED_MODEL_PATH_2)

        old_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=endpoint_name)
        old_config_name = old_endpoint["EndpointConfigName"]

        # Update endpoint
        multi_data_model.deploy(1,
                                alternative_cpu_instance_type,
                                endpoint_name=endpoint_name,
                                update_endpoint=True)

        # Wait for endpoint to finish updating
        for _ in retries(40,
                         "Waiting for 'InService' endpoint status",
                         seconds_to_sleep=30):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=endpoint_name)
            if new_endpoint["EndpointStatus"] == "InService":
                break

        new_config_name = new_endpoint["EndpointConfigName"]

        new_config = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=new_config_name)
        assert old_config_name != new_config_name
        assert new_config["ProductionVariants"][0][
            "InstanceType"] == alternative_cpu_instance_type
        assert new_config["ProductionVariants"][0]["InitialInstanceCount"] == 1

        # Cleanup
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=old_config_name)
        sagemaker_session.sagemaker_client.delete_endpoint_config(
            EndpointConfigName=new_config_name)
        multi_data_model.delete_model()
    # Each "not found" check needs its own pytest.raises block; statements after
    # the raising call inside a single block would never execute.
    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_model(ModelName=model_name)
    assert "Could not find model" in str(exception.value)
    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=old_config_name)
    assert "Could not find endpoint" in str(exception.value)
    with pytest.raises(Exception) as exception:
        sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=new_config_name)
    assert "Could not find endpoint" in str(exception.value)
def test_one_step_data_quality_pipeline_happycase(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    supplied_baseline_statistics_uri_param,
    supplied_baseline_constraints_uri_param,
    data_quality_check_config,
    data_quality_supplied_baseline_statistics,
):
    data_quality_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/quality_check_step/data_quality/good_cases/constraints.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    data_quality_check_step = QualityCheckStep(
        name="DataQualityCheckStep",
        skip_check=False,
        register_new_baseline=False,
        quality_check_config=data_quality_check_config,
        check_job_config=check_job_config,
        supplied_baseline_statistics=supplied_baseline_statistics_uri_param,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[data_quality_check_step],
        parameters=[
            supplied_baseline_statistics_uri_param,
            supplied_baseline_constraints_uri_param,
        ],
        sagemaker_session=sagemaker_session,
    )
    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(
                parameters={
                    "SuppliedBaselineStatisticsUri": data_quality_supplied_baseline_statistics,
                    "SuppliedBaselineConstraintsUri": data_quality_supplied_baseline_constraints,
                }
            )
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if failure_reason != "":
                logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
                continue
            assert execution_steps[0]["StepName"] == "DataQualityCheckStep"
            assert execution_steps[0]["StepStatus"] == "Succeeded"
            data_qual_metadata = execution_steps[0]["Metadata"]["QualityCheck"]
            assert not data_qual_metadata["SkipCheck"]
            assert not data_qual_metadata["RegisterNewBaseline"]
            assert not data_qual_metadata.get("ViolationReport", "")
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckConstraints"]
                == data_quality_supplied_baseline_constraints
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckStatistics"]
                == data_quality_supplied_baseline_statistics
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckConstraints"]
                != data_qual_metadata["CalculatedBaselineConstraints"]
            )
            assert (
                data_qual_metadata["BaselineUsedForDriftCheckStatistics"]
                != data_qual_metadata["CalculatedBaselineStatistics"]
            )
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
def test_one_step_model_quality_pipeline_constraint_violation(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    supplied_baseline_statistics_uri_param,
    supplied_baseline_constraints_uri_param,
    model_quality_check_config,
    model_quality_supplied_baseline_statistics,
):
    model_quality_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR, "pipeline/quality_check_step/model_quality/bad_cases/constraints.json"
        ),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    model_quality_check_step = QualityCheckStep(
        name="ModelQualityCheckStep",
        register_new_baseline=False,
        skip_check=False,
        quality_check_config=model_quality_check_config,
        check_job_config=check_job_config,
        supplied_baseline_statistics=supplied_baseline_statistics_uri_param,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[model_quality_check_step],
        parameters=[
            supplied_baseline_statistics_uri_param,
            supplied_baseline_constraints_uri_param,
        ],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        for _ in retries(
            max_retry_count=5,
            exception_message_prefix="Waiting for a successful execution of pipeline",
            seconds_to_sleep=10,
        ):
            execution = pipeline.start(
                parameters={
                    "SuppliedBaselineStatisticsUri": model_quality_supplied_baseline_statistics,
                    "SuppliedBaselineConstraintsUri": model_quality_supplied_baseline_constraints,
                }
            )
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if _CHECK_FAIL_ERROR_MSG not in failure_reason:
                logging.error(f"Pipeline execution failed with error: {failure_reason}. Retrying..")
                continue
            assert execution_steps[0]["StepName"] == "ModelQualityCheckStep"
            assert execution_steps[0]["StepStatus"] == "Failed"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
def test_disabling_data_capture_on_endpoint_shows_correct_data_capture_status(
        sagemaker_session, tensorflow_inference_latest_version):
    endpoint_name = unique_name_from_base("sagemaker-tensorflow-serving")
    model_data = sagemaker_session.upload_data(
        path=os.path.join(tests.integ.DATA_DIR,
                          "tensorflow-serving-test-model.tar.gz"),
        key_prefix="tensorflow-serving/models",
    )
    with tests.integ.timeout.timeout_and_delete_endpoint_by_name(
            endpoint_name, sagemaker_session):
        model = TensorFlowModel(
            model_data=model_data,
            role=ROLE,
            framework_version=tensorflow_inference_latest_version,
            sagemaker_session=sagemaker_session,
        )
        destination_s3_uri = os.path.join("s3://",
                                          sagemaker_session.default_bucket(),
                                          endpoint_name, "custom")
        predictor = model.deploy(
            initial_instance_count=INSTANCE_COUNT,
            instance_type=INSTANCE_TYPE,
            endpoint_name=endpoint_name,
            data_capture_config=DataCaptureConfig(
                enable_capture=True,
                sampling_percentage=CUSTOM_SAMPLING_PERCENTAGE,
                destination_s3_uri=destination_s3_uri,
                capture_options=CUSTOM_CAPTURE_OPTIONS,
                csv_content_types=CUSTOM_CSV_CONTENT_TYPES,
                json_content_types=CUSTOM_JSON_CONTENT_TYPES,
                sagemaker_session=sagemaker_session,
            ),
        )

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name)

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"])

        assert endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]
        assert (endpoint_config_desc["DataCaptureConfig"]
                ["InitialSamplingPercentage"] == CUSTOM_SAMPLING_PERCENTAGE)
        assert endpoint_config_desc["DataCaptureConfig"]["CaptureOptions"] == [
            {
                "CaptureMode": "Input"
            }
        ]
        assert (endpoint_config_desc["DataCaptureConfig"]
                ["CaptureContentTypeHeader"]["CsvContentTypes"] ==
                CUSTOM_CSV_CONTENT_TYPES)
        assert (endpoint_config_desc["DataCaptureConfig"]
                ["CaptureContentTypeHeader"]["JsonContentTypes"] ==
                CUSTOM_JSON_CONTENT_TYPES)

        predictor.disable_data_capture()

        # Wait for endpoint to finish updating
        # Endpoint update takes ~7min. 25 retries * 60s sleeps = 25min timeout
        for _ in retries(
                max_retry_count=25,
                exception_message_prefix=
                "Waiting for 'InService' endpoint status",
                seconds_to_sleep=60,
        ):
            new_endpoint = sagemaker_session.sagemaker_client.describe_endpoint(
                EndpointName=predictor.endpoint_name)
            if new_endpoint["EndpointStatus"] == "InService":
                break

        endpoint_desc = sagemaker_session.sagemaker_client.describe_endpoint(
            EndpointName=predictor.endpoint_name)

        endpoint_config_desc = sagemaker_session.sagemaker_client.describe_endpoint_config(
            EndpointConfigName=endpoint_desc["EndpointConfigName"])

        assert not endpoint_config_desc["DataCaptureConfig"]["EnableCapture"]
def test_training_job_with_debugger_and_profiler(
    sagemaker_session,
    pipeline_name,
    role,
    pytorch_training_latest_version,
    pytorch_training_latest_py_version,
):
    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType", default_value="ml.m5.xlarge")

    rules = [
        Rule.sagemaker(rule_configs.vanishing_gradient()),
        Rule.sagemaker(base_config=rule_configs.all_zero(), rule_parameters={"tensor_regex": ".*"}),
        Rule.sagemaker(rule_configs.loss_not_decreasing()),
    ]
    debugger_hook_config = DebuggerHookConfig(
        s3_output_path=(f"s3://{sagemaker_session.default_bucket()}/{uuid.uuid4()}/tensors")
    )

    base_dir = os.path.join(DATA_DIR, "pytorch_mnist")
    script_path = os.path.join(base_dir, "mnist.py")
    input_path = sagemaker_session.upload_data(
        path=os.path.join(base_dir, "training"),
        key_prefix="integ-test-data/pytorch_mnist/training",
    )
    inputs = TrainingInput(s3_data=input_path)

    pytorch_estimator = PyTorch(
        entry_point=script_path,
        role="SageMakerRole",
        framework_version=pytorch_training_latest_version,
        py_version=pytorch_training_latest_py_version,
        instance_count=instance_count,
        instance_type=instance_type,
        sagemaker_session=sagemaker_session,
        rules=rules,
        debugger_hook_config=debugger_hook_config,
    )

    step_train = TrainingStep(
        name="pytorch-train",
        estimator=pytorch_estimator,
        inputs=inputs,
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[instance_count, instance_type],
        steps=[step_train],
        sagemaker_session=sagemaker_session,
    )

    for _ in retries(
        max_retry_count=5,
        exception_message_prefix="Waiting for a successful execution of pipeline",
        seconds_to_sleep=10,
    ):
        try:
            response = pipeline.create(role)
            create_arn = response["PipelineArn"]

            execution = pipeline.start()
            response = execution.describe()
            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=10, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if failure_reason != "":
                logging.error(f"Pipeline execution failed with error: {failure_reason}.Retrying..")
                continue
            assert execution_steps[0]["StepName"] == "pytorch-train"
            assert execution_steps[0]["StepStatus"] == "Succeeded"

            training_job_arn = execution_steps[0]["Metadata"]["TrainingJob"]["Arn"]
            job_description = sagemaker_session.sagemaker_client.describe_training_job(
                TrainingJobName=training_job_arn.split("/")[1]
            )

            for index, rule in enumerate(rules):
                config = job_description["DebugRuleConfigurations"][index]
                assert config["RuleConfigurationName"] == rule.name
                assert config["RuleEvaluatorImage"] == rule.image_uri
                assert config["VolumeSizeInGB"] == 0
                assert (
                    config["RuleParameters"]["rule_to_invoke"]
                    == rule.rule_parameters["rule_to_invoke"]
                )
            assert job_description["DebugHookConfig"] == debugger_hook_config._to_request_dict()

            assert job_description["ProfilingStatus"] == "Enabled"
            assert job_description["ProfilerConfig"]["ProfilingIntervalInMilliseconds"] == 500
            break
        finally:
            try:
                pipeline.delete()
            except Exception:
                pass
def test_one_step_data_bias_pipeline_constraint_violation(
    sagemaker_session,
    role,
    pipeline_name,
    check_job_config,
    data_bias_check_config,
    supplied_baseline_constraints_uri_param,
):
    data_bias_supplied_baseline_constraints = Constraints.from_file_path(
        constraints_file_path=os.path.join(
            DATA_DIR,
            "pipeline/clarify_check_step/data_bias/bad_cases/analysis.json"),
        sagemaker_session=sagemaker_session,
    ).file_s3_uri
    data_bias_check_step = ClarifyCheckStep(
        name="DataBiasCheckStep",
        clarify_check_config=data_bias_check_config,
        check_job_config=check_job_config,
        skip_check=False,
        register_new_baseline=False,
        supplied_baseline_constraints=supplied_baseline_constraints_uri_param,
    )
    pipeline = Pipeline(
        name=pipeline_name,
        steps=[data_bias_check_step],
        parameters=[supplied_baseline_constraints_uri_param],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]
        monitoring_analysis_cfg_json = S3Downloader.read_file(
            data_bias_check_config.monitoring_analysis_config_uri,
            sagemaker_session,
        )
        monitoring_analysis_cfg = json.loads(monitoring_analysis_cfg_json)

        assert monitoring_analysis_cfg is not None and len(
            monitoring_analysis_cfg) > 0

        for _ in retries(
                max_retry_count=5,
                exception_message_prefix=
                "Waiting for a successful execution of pipeline",
                seconds_to_sleep=10,
        ):
            execution = pipeline.start(parameters={
                "SuppliedBaselineConstraintsUri":
                data_bias_supplied_baseline_constraints
            }, )
            response = execution.describe()

            assert response["PipelineArn"] == create_arn

            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 1
            failure_reason = execution_steps[0].get("FailureReason", "")
            if _CHECK_FAIL_ERROR_MSG not in failure_reason:
                logging.error(
                    f"Pipeline execution failed with error: {failure_reason}. Retrying.."
                )
                continue
            assert execution_steps[0]["StepName"] == "DataBiasCheckStep"
            assert execution_steps[0]["StepStatus"] == "Failed"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass
def test_model_registration_with_tensorflow_model_with_pipeline_model(
        sagemaker_session, role, tf_full_version, tf_full_py_version,
        pipeline_name, region_name):
    base_dir = os.path.join(DATA_DIR, "tensorflow_mnist")
    entry_point = os.path.join(base_dir, "mnist_v2.py")
    input_path = sagemaker_session.upload_data(
        path=os.path.join(base_dir, "data"),
        key_prefix="integ-test-data/tf-scriptmode/mnist/training",
    )
    inputs = TrainingInput(s3_data=input_path)

    instance_count = ParameterInteger(name="InstanceCount", default_value=1)
    instance_type = ParameterString(name="InstanceType",
                                    default_value="ml.m5.xlarge")

    tensorflow_estimator = TensorFlow(
        entry_point=entry_point,
        role=role,
        instance_count=instance_count,
        instance_type=instance_type,
        framework_version=tf_full_version,
        py_version=tf_full_py_version,
        sagemaker_session=sagemaker_session,
    )
    step_train = TrainingStep(
        name="MyTrain",
        estimator=tensorflow_estimator,
        inputs=inputs,
    )

    model = TensorFlowModel(
        entry_point=entry_point,
        framework_version="2.4",
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        role=role,
        sagemaker_session=sagemaker_session,
    )

    pipeline_model = PipelineModel(name="MyModelPipeline",
                                   models=[model],
                                   role=role,
                                   sagemaker_session=sagemaker_session)

    step_register_model = RegisterModel(
        name="MyRegisterModel",
        model=pipeline_model,
        model_data=step_train.properties.ModelArtifacts.S3ModelArtifacts,
        content_types=["application/json"],
        response_types=["application/json"],
        inference_instances=["ml.t2.medium", "ml.m5.large"],
        transform_instances=["ml.m5.large"],
        model_package_group_name=f"{pipeline_name}TestModelPackageGroup",
    )

    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[
            instance_count,
            instance_type,
        ],
        steps=[step_train, step_register_model],
        sagemaker_session=sagemaker_session,
    )

    try:
        response = pipeline.create(role)
        create_arn = response["PipelineArn"]

        assert re.match(
            rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}",
            create_arn,
        )

        for _ in retries(
                max_retry_count=5,
                exception_message_prefix=
                "Waiting for a successful execution of pipeline",
                seconds_to_sleep=10,
        ):
            execution = pipeline.start(parameters={})
            assert re.match(
                rf"arn:aws:sagemaker:{region_name}:\d{{12}}:pipeline/{pipeline_name}/execution/",
                execution.arn,
            )
            try:
                execution.wait(delay=30, max_attempts=60)
            except WaiterError:
                pass
            execution_steps = execution.list_steps()

            assert len(execution_steps) == 3
            for step in execution_steps:
                assert step["StepStatus"] == "Succeeded"
            break
    finally:
        try:
            pipeline.delete()
        except Exception:
            pass