Code example #1
def _get_processing_inputs_with_all_parameters(bucket):
    return [
        ProcessingInput(
            source=f"s3://{bucket}",
            destination="/opt/ml/processing/input/data/",
            input_name="my_dataset",
        ),
        ProcessingInput(
            input_name="s3_input",
            s3_input=S3Input(
                s3_uri=f"s3://{bucket}",
                local_path="/opt/ml/processing/input/s3_input",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="None",
            ),
        ),
        ProcessingInput(
            input_name="redshift_dataset_definition",
            app_managed=True,
            dataset_definition=DatasetDefinition(
                local_path="/opt/ml/processing/input/rdd",
                data_distribution_type="FullyReplicated",
                input_mode="File",
                redshift_dataset_definition=RedshiftDatasetDefinition(
                    cluster_id="integ-test-cluster-prod-us-west-2",
                    database="dev",
                    db_user="******",
                    query_string="SELECT * FROM shoes",
                    cluster_role_arn="arn:aws:iam::037210630505:role/RedshiftClusterRole-prod-us-west-2",
                    output_s3_uri=f"s3://{bucket}/rdd",
                    output_format="CSV",
                    output_compression="None",
                ),
            ),
        ),
        ProcessingInput(
            input_name="athena_dataset_definition",
            app_managed=True,
            dataset_definition=DatasetDefinition(
                local_path="/opt/ml/processing/input/add",
                data_distribution_type="FullyReplicated",
                input_mode="File",
                athena_dataset_definition=AthenaDatasetDefinition(
                    catalog="AwsDataCatalog",
                    database="default",
                    work_group="workgroup",
                    query_string='SELECT * FROM "default"."s3_test_table_$STAGE_$REGIONUNDERSCORED";',
                    output_s3_uri=f"s3://{bucket}/add",
                    output_format="JSON",
                    output_compression="GZIP",
                ),
            ),
        ),
    ]
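A hedged sketch of how a list of inputs like this might be handed to a processing job; the role ARN, image URI, script name, and bucket are placeholders rather than part of the example above:

# Hedged usage sketch; role ARN, image URI, script, and bucket are placeholders.
from sagemaker.processing import ScriptProcessor

processor = ScriptProcessor(
    role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",  # placeholder role
    image_uri="<processing-container-image-uri>",                  # placeholder image
    command=["python3"],
    instance_count=1,
    instance_type="ml.m5.xlarge",
)
processor.run(
    code="process.py",  # placeholder processing script
    inputs=_get_processing_inputs_with_all_parameters("my-bucket"),
)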
Code example #2
    @classmethod
    def from_processing_name(cls, sagemaker_session, processing_job_name):
        """Initializes a ``ProcessingJob`` from a processing job name.

        Args:
            processing_job_name (str): Name of the processing job.
            sagemaker_session (:class:`~sagemaker.session.Session`):
                Session object which manages interactions with Amazon SageMaker and
                any other AWS services needed. If not specified, the processor creates
                one using the default AWS configuration chain.

        Returns:
            :class:`~sagemaker.processing.ProcessingJob`: The instance of ``ProcessingJob`` created
                from the job name.
        """
        job_desc = sagemaker_session.describe_processing_job(job_name=processing_job_name)

        inputs = None
        if job_desc.get("ProcessingInputs"):
            inputs = [
                ProcessingInput(
                    input_name=processing_input["InputName"],
                    s3_input=S3Input.from_boto(processing_input.get("S3Input")),
                    dataset_definition=DatasetDefinition.from_boto(
                        processing_input.get("DatasetDefinition")
                    ),
                    app_managed=processing_input.get("AppManaged", False),
                )
                for processing_input in job_desc["ProcessingInputs"]
            ]

        outputs = None
        if job_desc.get("ProcessingOutputConfig") and job_desc["ProcessingOutputConfig"].get(
            "Outputs"
        ):
            outputs = []
            for processing_output_dict in job_desc["ProcessingOutputConfig"]["Outputs"]:
                processing_output = ProcessingOutput(
                    output_name=processing_output_dict["OutputName"],
                    app_managed=processing_output_dict.get("AppManaged", False),
                    feature_store_output=FeatureStoreOutput.from_boto(
                        processing_output_dict.get("FeatureStoreOutput")
                    ),
                )

                if "S3Output" in processing_output_dict:
                    processing_output.source = processing_output_dict["S3Output"]["LocalPath"]
                    processing_output.destination = processing_output_dict["S3Output"]["S3Uri"]

                outputs.append(processing_output)
        output_kms_key = None
        if job_desc.get("ProcessingOutputConfig"):
            output_kms_key = job_desc["ProcessingOutputConfig"].get("KmsKeyId")

        return cls(
            sagemaker_session=sagemaker_session,
            job_name=processing_job_name,
            inputs=inputs,
            outputs=outputs,
            output_kms_key=output_kms_key,
        )
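A hedged sketch of calling this constructor to rebuild a job handle from an existing processing job; the job name is a placeholder:

# Hedged usage sketch; "my-processing-job" is a placeholder job name.
import sagemaker
from sagemaker.processing import ProcessingJob

session = sagemaker.Session()
job = ProcessingJob.from_processing_name(session, processing_job_name="my-processing-job")
print(job.inputs, job.outputs)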
Code example #3
def athena_dataset_definition(sagemaker_session):
    return DatasetDefinition(
        local_path="/opt/ml/processing/input/add",
        data_distribution_type="FullyReplicated",
        input_mode="File",
        athena_dataset_definition=AthenaDatasetDefinition(
            catalog="AwsDataCatalog",
            database="default",
            work_group="workgroup",
            query_string='SELECT * FROM "default"."s3_test_table_$STAGE_$REGIONUNDERSCORED";',
            output_s3_uri=f"s3://{sagemaker_session.default_bucket()}/add",
            output_format="JSON",
            output_compression="GZIP",
        ),
    )
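The helper above reads like a pytest fixture from the SDK's integration tests. A hedged sketch of how a test might consume it, assuming it is registered with @pytest.fixture; the test name and assertion are illustrative only:

# Hedged sketch: assumes the helper above is decorated with @pytest.fixture
# (e.g. in a conftest.py), so the test receives the DatasetDefinition directly.
from sagemaker.processing import ProcessingInput

def test_athena_dataset_definition(athena_dataset_definition):
    athena_input = ProcessingInput(
        input_name="athena_dataset_definition",
        app_managed=True,  # matches the app-managed dataset-definition inputs above
        dataset_definition=athena_dataset_definition,
    )
    assert athena_input.dataset_definition.athena_dataset_definition.output_format == "JSON"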
Code example #4
def create_athena_processing_input(athena_dataset_definition, name, base_dir):
    """Create an Athena processing input for a DW job

    (From Data Wrangler Job notebook template 2021-03-10)
    """
    return ProcessingInput(
        input_name=name,
        dataset_definition=DatasetDefinition(
            local_path=f"{base_dir}/{name}",
            data_distribution_type="FullyReplicated",
            athena_dataset_definition=AthenaDatasetDefinition(
                catalog=athena_dataset_definition["catalogName"],
                database=athena_dataset_definition["databaseName"],
                query_string=athena_dataset_definition["queryString"],
                output_s3_uri=athena_dataset_definition["s3OutputLocation"] + f"{name}/",
                output_format=athena_dataset_definition["outputFormat"].upper(),
            ),
        ),
    )
Code example #5
def create_redshift_processing_input(redshift_dataset_definition, name, base_dir):
    """Create a Redshift processing input for a DW job

    (From Data Wrangler Job notebook template 2021-03-10)
    """
    return ProcessingInput(
        input_name=name,
        dataset_definition=DatasetDefinition(
            local_path=f"{base_dir}/{name}",
            data_distribution_type="FullyReplicated",
            redshift_dataset_definition=RedshiftDatasetDefinition(
                cluster_id=redshift_dataset_definition["clusterIdentifier"],
                database=redshift_dataset_definition["database"],
                db_user=redshift_dataset_definition["dbUser"],
                query_string=redshift_dataset_definition["queryString"],
                cluster_role_arn=redshift_dataset_definition["unloadIamRole"],
                output_s3_uri=redshift_dataset_definition["s3OutputLocation"] + f"{name}/",
                output_format=redshift_dataset_definition["outputFormat"].upper(),
            ),
        ),
    )
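Both helpers above consume dictionaries parsed from a Data Wrangler .flow file's data-source nodes. A hedged sketch of how they might be called; every value below is a placeholder, and only the key names come from the helpers above:

# Hedged sketch: hypothetical data-source dictionaries in the shape the helpers expect.
athena_source = {
    "catalogName": "AwsDataCatalog",
    "databaseName": "default",
    "queryString": "SELECT * FROM my_table",
    "s3OutputLocation": "s3://my-bucket/athena/",
    "outputFormat": "parquet",
}
redshift_source = {
    "clusterIdentifier": "my-cluster",
    "database": "dev",
    "dbUser": "awsuser",
    "queryString": "SELECT * FROM my_table",
    "unloadIamRole": "arn:aws:iam::123456789012:role/RedshiftUnloadRole",
    "s3OutputLocation": "s3://my-bucket/redshift/",
    "outputFormat": "csv",
}

inputs = [
    create_athena_processing_input(athena_source, "athena-data", "/opt/ml/processing"),
    create_redshift_processing_input(redshift_source, "redshift-data", "/opt/ml/processing"),
]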
Code example #6
    def _to_request_dict(self):
        """Generates a request dictionary using the parameters provided to the class."""

        # Create the request dictionary.
        s3_input_request = {"InputName": self.input_name, "AppManaged": self.app_managed}

        if self.s3_input:
            # Check the compression type, then add it to the dictionary.
            if (
                self.s3_input.s3_compression_type == "Gzip"
                and self.s3_input.s3_input_mode != "Pipe"
            ):
                raise ValueError("Data can only be gzipped when the input mode is Pipe.")

            s3_input_request["S3Input"] = S3Input.to_boto(self.s3_input)

        if self.dataset_definition is not None:
            s3_input_request["DatasetDefinition"] = DatasetDefinition.to_boto(
                self.dataset_definition
            )

        # Return the request dictionary.
        return s3_input_request
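For illustration, a hedged sketch of roughly what this method returns for an explicit S3Input; the URI and local path are placeholders, and the key names follow the CreateProcessingJob request shape:

# Hedged sketch; URI and local path are placeholders, output shown is approximate.
from sagemaker.dataset_definition.inputs import S3Input
from sagemaker.processing import ProcessingInput

processing_input = ProcessingInput(
    input_name="s3_input",
    s3_input=S3Input(
        s3_uri="s3://my-bucket/data/",
        local_path="/opt/ml/processing/input/s3_input",
        s3_data_type="S3Prefix",
        s3_input_mode="File",
        s3_data_distribution_type="FullyReplicated",
        s3_compression_type="None",
    ),
)
request = processing_input._to_request_dict()
# request is roughly:
# {"InputName": "s3_input", "AppManaged": False,
#  "S3Input": {"S3Uri": "s3://my-bucket/data/",
#              "LocalPath": "/opt/ml/processing/input/s3_input",
#              "S3DataType": "S3Prefix", "S3InputMode": "File",
#              "S3DataDistributionType": "FullyReplicated",
#              "S3CompressionType": "None"}}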
Code example #7
def _get_data_inputs_all_parameters():
    return [
        ProcessingInput(
            source="s3://path/to/my/dataset/census.csv",
            destination="/container/path/",
            input_name="my_dataset",
            s3_data_type="S3Prefix",
            s3_input_mode="File",
            s3_data_distribution_type="FullyReplicated",
            s3_compression_type="None",
        ),
        ProcessingInput(
            input_name="s3_input",
            s3_input=S3Input(
                s3_uri="s3://path/to/my/dataset/census.csv",
                local_path="/container/path/",
                s3_data_type="S3Prefix",
                s3_input_mode="File",
                s3_data_distribution_type="FullyReplicated",
                s3_compression_type="None",
            ),
        ),
        ProcessingInput(
            input_name="redshift_dataset_definition",
            app_managed=True,
            dataset_definition=DatasetDefinition(
                local_path="/opt/ml/processing/input/dd",
                data_distribution_type="FullyReplicated",
                input_mode="File",
                redshift_dataset_definition=RedshiftDatasetDefinition(
                    cluster_id="cluster_id",
                    database="database",
                    db_user="******",
                    query_string="query_string",
                    cluster_role_arn="cluster_role_arn",
                    output_s3_uri="output_s3_uri",
                    kms_key_id="kms_key_id",
                    output_format="CSV",
                    output_compression="SNAPPY",
                ),
            ),
        ),
        ProcessingInput(
            input_name="athena_dataset_definition",
            app_managed=True,
            dataset_definition=DatasetDefinition(
                local_path="/opt/ml/processing/input/dd",
                data_distribution_type="FullyReplicated",
                input_mode="File",
                athena_dataset_definition=AthenaDatasetDefinition(
                    catalog="catalog",
                    database="database",
                    work_group="workgroup",
                    query_string="query_string",
                    output_s3_uri="output_s3_uri",
                    kms_key_id="kms_key_id",
                    output_format="AVRO",
                    output_compression="ZLIB",
                ),
            ),
        ),
    ]
Code example #8
def get_pipeline(
    region,
    sagemaker_project_arn=None,
    role=None,
    default_bucket=None,
    model_package_group_name="restatePackageGroup",  # Choose any name
    pipeline_name="restate-p-XXXXXXXXX",  # You can find your pipeline name in the Studio UI (project -> Pipelines -> name)
    base_job_prefix="restate",  # Choose any name
):
    """Gets a SageMaker ML Pipeline instance working with on RE data.
    Args:
        region: AWS region to create and run the pipeline.
        role: IAM role to create and run steps and pipeline.
        default_bucket: the bucket to use for storing the artifacts
    Returns:
        an instance of a pipeline
    """
    sagemaker_session = get_session(region, default_bucket)
    if role is None:
        role = sagemaker.session.get_execution_role(sagemaker_session)

    # Parameters for pipeline execution
    processing_instance_count = ParameterInteger(name="ProcessingInstanceCount", default_value=1)
    processing_instance_type = ParameterString(
        name="ProcessingInstanceType", default_value="ml.m5.2xlarge"
    )
    training_instance_type = ParameterString(
        name="TrainingInstanceType", default_value="ml.m5.xlarge"
    )
    model_approval_status = ParameterString(
        name="ModelApprovalStatus",
        default_value="PendingManualApproval",  # ModelApprovalStatus can be set to a default of "Approved" if you don't want manual approval.
    )
    input_data = ParameterString(
        name="InputDataUrl",
        default_value=f"",  # Change this to point to the s3 location of your raw input data.
    )

    data_sources = []
    # Sagemaker session
    sess = sagemaker_session

    # You can configure this with your own bucket name, e.g.
    # bucket = "my-bucket"
    bucket = sess.default_bucket()

    data_sources.append(
        ProcessingInput(
            input_name="restate-california",
            dataset_definition=DatasetDefinition(
                local_path="/opt/ml/processing/restate-california",
                data_distribution_type="FullyReplicated",
                # You can override below to point to other database or use different queries
                athena_dataset_definition=AthenaDatasetDefinition(
                    catalog="AwsDataCatalog",
                    database="restate",
                    query_string="SELECT * FROM restate.california_10",
                    output_s3_uri=f"s3://{bucket}/athena/",
                    output_format="PARQUET",
                ),
            ),
        )
    )

    print(f"Data Wrangler export storage bucket: {bucket}")

    # unique flow export ID
    flow_export_id = f"{time.strftime('%d-%H-%M-%S', time.gmtime())}-{str(uuid.uuid4())[:8]}"
    flow_export_name = f"flow-{flow_export_id}"

    # Output name is auto-generated from the select node's ID + output name from the flow file.
    output_name = "99ae1ec3-dd5f-453c-bfae-721dac423cd7.default"

    s3_output_prefix = f"export-{flow_export_name}/output"
    s3_output_path = f"s3://{bucket}/{s3_output_prefix}"
    print(f"Flow S3 export result path: {s3_output_path}")

    processing_job_output = ProcessingOutput(
        output_name=output_name,
        source="/opt/ml/processing/output",
        destination=s3_output_path,
        s3_upload_mode="EndOfJob",
    )

    # name of the flow file which should exist in the current notebook working directory
    flow_file_name = "sagemaker-pipeline/restate-athena-california.flow"

    # Load .flow file from current notebook working directory
    #!echo "Loading flow file from current notebook working directory: $PWD"

    with open(flow_file_name) as f:
        flow = json.load(f)

    # Upload flow to S3
    s3_client = boto3.client("s3")
    s3_client.upload_file(
        flow_file_name,
        bucket,
        f"data_wrangler_flows/{flow_export_name}.flow",
        ExtraArgs={"ServerSideEncryption": "aws:kms"},
    )

    flow_s3_uri = f"s3://{bucket}/data_wrangler_flows/{flow_export_name}.flow"

    print(f"Data Wrangler flow {flow_file_name} uploaded to {flow_s3_uri}")

    ## Input - Flow: restate-athena-california.flow
    flow_input = ProcessingInput(
        source=flow_s3_uri,
        destination="/opt/ml/processing/flow",
        input_name="flow",
        s3_data_type="S3Prefix",
        s3_input_mode="File",
        s3_data_distribution_type="FullyReplicated",
    )

    # IAM role for executing the processing job.
    iam_role = role

    # Unique processing job name. Give a unique name every time you re-execute processing jobs
    processing_job_name = f"data-wrangler-flow-processing-{flow_export_id}"

    # Data Wrangler Container URL.
    container_uri = sagemaker.image_uris.retrieve(
        framework="data-wrangler",  # we are using the Sagemaker built in xgboost algorithm
        region=region,
    )

    # Processing Job Instance count and instance type.
    instance_count = 2
    instance_type = "ml.m5.4xlarge"

    # Size in GB of the EBS volume to use for storing data during processing
    volume_size_in_gb = 30

    # Content type for each output. Data Wrangler supports CSV as default and Parquet.
    output_content_type = "CSV"

    # Network Isolation mode; default is off
    enable_network_isolation = False

    # List of tags to be passed to the processing job
    user_tags = []

    # Output configuration used as processing job container arguments
    output_config = {output_name: {"content_type": output_content_type}}

    # KMS key for per object encryption; default is None
    kms_key = None

    processor = Processor(
        role=iam_role,
        image_uri=container_uri,
        instance_count=instance_count,
        instance_type=instance_type,
        volume_size_in_gb=volume_size_in_gb,
        network_config=NetworkConfig(enable_network_isolation=enable_network_isolation),
        sagemaker_session=sess,
        output_kms_key=kms_key,
        tags=user_tags,
    )

    data_wrangler_step = ProcessingStep(
        name="DataWranglerProcess",
        processor=processor,
        inputs=[flow_input] + data_sources,
        outputs=[processing_job_output],
        job_arguments=[f"--output-config '{json.dumps(output_config)}'"],
    )

    # Processing step for feature engineering
    # this processor does not have awswrangler installed
    sklearn_processor = SKLearnProcessor(
        framework_version="0.23-1",
        instance_type=processing_instance_type,
        instance_count=processing_instance_count,
        base_job_name=f"{base_job_prefix}/sklearn-restate-preprocess",  # choose any name
        sagemaker_session=sagemaker_session,
        role=role,
    )

    step_process = ProcessingStep(
        name="Preprocess",  # choose any name
        processor=sklearn_processor,
        inputs=[
            ProcessingInput(
                source=data_wrangler_step.properties.ProcessingOutputConfig.Outputs[
                    output_name
                ].S3Output.S3Uri,
                destination="/opt/ml/processing/data/raw-data-dir",
            )
        ],
        outputs=[
            ProcessingOutput(output_name="train", source="/opt/ml/processing/train"),
            ProcessingOutput(output_name="validation", source="/opt/ml/processing/validation"),
            ProcessingOutput(output_name="test", source="/opt/ml/processing/test"),
        ],
        code=os.path.join(BASE_DIR, "preprocess.py"),
        job_arguments=[
            "--input-data",
            data_wrangler_step.properties.ProcessingOutputConfig.Outputs[
                output_name
            ].S3Output.S3Uri,
        ],
    )

    # Training step for generating model artifacts
    model_path = f"s3://{sagemaker_session.default_bucket()}/{base_job_prefix}/restateTrain"
    model_bucket_key = f"{sagemaker_session.default_bucket()}/{base_job_prefix}/restateTrain"
    cache_config = CacheConfig(enable_caching=True, expire_after="30d")

    xgb_image_uri = sagemaker.image_uris.retrieve(
        framework="xgboost",  # we are using the Sagemaker built in xgboost algorithm
        region=region,
        version="1.0-1",
        py_version="py3",
        instance_type=training_instance_type,
    )
    xgb_train = Estimator(
        image_uri=xgb_image_uri,
        instance_type=training_instance_type,
        instance_count=1,
        output_path=model_path,
        base_job_name=f"{base_job_prefix}/restate-xgb-train",
        sagemaker_session=sagemaker_session,
        role=role,
    )
    xgb_train.set_hyperparameters(
        #    #objective="binary:logistic",
        #    objective="reg:linear",
        num_round=50,
        #    max_depth=5,
        #    eta=0.2,
        #    gamma=4,
        #    min_child_weight=6,
        #    subsample=0.7,
        #    silent=0,
    )

    xgb_train.set_hyperparameters(grow_policy="lossguide")

    xgb_objective_metric_name = "validation:mse"
    xgb_hyperparameter_ranges = {
        "max_depth": IntegerParameter(2, 10, scaling_type="Linear"),
    }

    xgb_tuner_log = HyperparameterTuner(
        xgb_train,
        xgb_objective_metric_name,
        xgb_hyperparameter_ranges,
        max_jobs=3,
        max_parallel_jobs=3,
        strategy="Random",
        objective_type="Minimize",
    )

    xgb_step_tuning = TuningStep(
        name="XGBHPTune",
        tuner=xgb_tuner_log,
        inputs={
            "train": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "train"
                ].S3Output.S3Uri,
                content_type="text/csv",
            ),
            "validation": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "validation"
                ].S3Output.S3Uri,
                content_type="text/csv",
            ),
        },
        cache_config=cache_config,
    )

    # dtree_image_uri = '625467769535.dkr.ecr.ap-southeast-1.amazonaws.com/sagemaker-decision-tree:latest'
    dtree_image_uri = sagemaker_session.sagemaker_client.describe_image_version(
        ImageName="restate-dtree"
    )["ContainerImage"]

    dtree_train = Estimator(
        image_uri=dtree_image_uri,
        role=role,
        instance_count=1,
        instance_type=training_instance_type,
        base_job_name=f"{base_job_prefix}/restate-dtree-train",
        output_path=model_path,
        sagemaker_session=sagemaker_session,
    )

    dtree_objective_metric_name = "validation:mse"
    dtree_metric_definitions = [{"Name": "validation:mse", "Regex": r"mse:(\S+)"}]

    dtree_hyperparameter_ranges = {
        "max_depth": IntegerParameter(10, 50, scaling_type="Linear"),
        "max_leaf_nodes": IntegerParameter(2, 12, scaling_type="Linear"),
    }

    dtree_tuner_log = HyperparameterTuner(
        dtree_train,
        dtree_objective_metric_name,
        dtree_hyperparameter_ranges,
        dtree_metric_definitions,
        max_jobs=3,
        max_parallel_jobs=3,
        strategy="Random",
        objective_type="Minimize",
    )

    dtree_step_tuning = TuningStep(
        name="DTreeHPTune",
        tuner=dtree_tuner_log,
        inputs={
            "training": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "train"
                ].S3Output.S3Uri,
                content_type="text/csv",
            ),
            "validation": TrainingInput(
                s3_data=step_process.properties.ProcessingOutputConfig.Outputs[
                    "validation"
                ].S3Output.S3Uri,
                content_type="text/csv",
            ),
        },
        cache_config=cache_config,
    )

    dtree_script_eval = ScriptProcessor(
        image_uri=dtree_image_uri,
        command=["python3"],
        instance_type=processing_instance_type,
        instance_count=1,
        base_job_name=f"{base_job_prefix}/script-dtree-eval",
        sagemaker_session=sagemaker_session,
        role=role,
    )

    dtree_evaluation_report = PropertyFile(
        name="EvaluationReportDTree",
        output_name="dtree_evaluation",
        path="dtree_evaluation.json",
    )

    dtree_step_eval = ProcessingStep(
        name="DTreeEval",
        processor=dtree_script_eval,
        inputs=[
            ProcessingInput(
                # source=dtree_step_train.properties.ModelArtifacts.S3ModelArtifacts,
                source=dtree_step_tuning.get_top_model_s3_uri(top_k=0, s3_bucket=model_bucket_key),
                destination="/opt/ml/processing/model",
            ),
            ProcessingInput(
                source=step_process.properties.ProcessingOutputConfig.Outputs[
                    "test"
                ].S3Output.S3Uri,
                destination="/opt/ml/processing/test",
            ),
        ],
        outputs=[
            ProcessingOutput(
                output_name="dtree_evaluation", source="/opt/ml/processing/evaluation"
            ),
        ],
        code=os.path.join(BASE_DIR, "dtree_evaluate.py"),
        property_files=[dtree_evaluation_report],
    )

    xgb_script_eval = ScriptProcessor(
        image_uri=xgb_image_uri,
        command=["python3"],
        instance_type=processing_instance_type,
        instance_count=1,
        base_job_name=f"{base_job_prefix}/script-xgb-eval",
        sagemaker_session=sagemaker_session,
        role=role,
    )

    xgb_evaluation_report = PropertyFile(
        name="EvaluationReportXGBoost",
        output_name="xgb_evaluation",
        path="xgb_evaluation.json",
    )

    xgb_step_eval = ProcessingStep(
        name="XGBEval",
        processor=xgb_script_eval,
        inputs=[
            ProcessingInput(
                source=xgb_step_tuning.get_top_model_s3_uri(top_k=0, s3_bucket=model_bucket_key),
                destination="/opt/ml/processing/model",
            ),
            ProcessingInput(
                source=step_process.properties.ProcessingOutputConfig.Outputs[
                    "test"
                ].S3Output.S3Uri,
                destination="/opt/ml/processing/test",
            ),
        ],
        outputs=[
            ProcessingOutput(output_name="xgb_evaluation", source="/opt/ml/processing/evaluation"),
        ],
        code=os.path.join(BASE_DIR, "xgb_evaluate.py"),
        property_files=[xgb_evaluation_report],
    )

    xgb_model_metrics = ModelMetrics(
        model_statistics=MetricsSource(
            s3_uri="{}/xgb_evaluation.json".format(
                xgb_step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"]["S3Uri"]
            ),
            content_type="application/json",
        )
    )

    dtree_model_metrics = ModelMetrics(
        model_statistics=MetricsSource(
            s3_uri="{}/dtree_evaluation.json".format(
                dtree_step_eval.arguments["ProcessingOutputConfig"]["Outputs"][0]["S3Output"][
                    "S3Uri"
                ]
            ),
            content_type="application/json",
        )
    )

    xgb_eval_metrics = JsonGet(
        step=xgb_step_eval,
        property_file=xgb_evaluation_report,
        json_path="regression_metrics.r2s.value",  # This should follow the structure of your report_dict defined in the evaluate.py file.
    )

    dtree_eval_metrics = JsonGet(
        step=dtree_step_eval,
        property_file=dtree_evaluation_report,
        json_path="regression_metrics.r2s.value",  # This should follow the structure of your report_dict defined in the evaluate.py file.
    )

    # Register model step that will be conditionally executed
    dtree_step_register = RegisterModel(
        name="DTreeReg",
        estimator=dtree_train,
        model_data=dtree_step_tuning.get_top_model_s3_uri(top_k=0, s3_bucket=model_bucket_key),
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.t2.medium", "ml.m5.large"],
        transform_instances=["ml.m5.large"],
        model_package_group_name=model_package_group_name,
        approval_status=model_approval_status,
        model_metrics=dtree_model_metrics,
    )

    # Register model step that will be conditionally executed
    xgb_step_register = RegisterModel(
        name="XGBReg",
        estimator=xgb_train,
        model_data=xgb_step_tuning.get_top_model_s3_uri(top_k=0, s3_bucket=model_bucket_key),
        content_types=["text/csv"],
        response_types=["text/csv"],
        inference_instances=["ml.t2.medium", "ml.m5.large"],
        transform_instances=["ml.m5.large"],
        model_package_group_name=model_package_group_name,
        approval_status=model_approval_status,
        model_metrics=xgb_model_metrics,
    )

    # Condition step for evaluating model quality and branching execution
    cond_gte = ConditionGreaterThanOrEqualTo(  # You can change the condition here
        left=JsonGet(
            step=dtree_step_eval,
            property_file=dtree_evaluation_report,
            json_path="regression_metrics.r2s.value",  # This should follow the structure of your report_dict defined in the evaluate.py file.
        ),
        right=JsonGet(
            step=xgb_step_eval,
            property_file=xgb_evaluation_report,
            json_path="regression_metrics.r2s.value",  # This should follow the structure of your report_dict defined in the evaluate.py file.
        ),  # Compares the decision-tree R2 score against the XGBoost R2 score
    )

    step_cond = ConditionStep(
        name="AccuracyCond",
        conditions=[cond_gte],
        if_steps=[dtree_step_register],
        else_steps=[xgb_step_register],
    )
    create_date = time.strftime("%Y-%m-%d-%H-%M-%S")

    # Pipeline instance
    pipeline = Pipeline(
        name=pipeline_name,
        parameters=[
            processing_instance_type,
            processing_instance_count,
            training_instance_type,
            model_approval_status,
            input_data
        ],
        pipeline_experiment_config=PipelineExperimentConfig(
            pipeline_name + "-" + create_date, "restate-{}".format(create_date)
        ),
        steps=[
            data_wrangler_step,
            step_process,
            dtree_step_tuning,
            xgb_step_tuning,
            dtree_step_eval,
            xgb_step_eval,
            step_cond,
        ],
        sagemaker_session=sagemaker_session,
    )
    return pipeline
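A hedged sketch of deploying and running the pipeline returned above; the region, role ARN, and bucket name are placeholders:

# Hedged usage sketch; region, role ARN, and bucket name are placeholders.
role_arn = "arn:aws:iam::123456789012:role/SageMakerPipelineExecutionRole"

pipeline = get_pipeline(
    region="us-west-2",
    role=role_arn,
    default_bucket="my-sagemaker-bucket",
)

# Create or update the pipeline definition in SageMaker, then start an execution.
pipeline.upsert(role_arn=role_arn)
execution = pipeline.start()
execution.wait()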