def _construct_toposorted_chain_of_tasks(self) -> Chain:
        """Take the directed graph and toposort so that we can efficiently
        organize our workflow, i.e. parallelize where possible.

        If we have 2 elements and one of them is an Ellipsis object, we need to orchestrate just 1 job.
        Otherwise we loop over the toposorted dag and assign a single stepfunctions task
        or multiple tasks in parallel.

        Returns: toposorted chain of tasks
        """
        self.chain_of_tasks = Chain()
        directed_graph_toposorted = list(toposort.toposort(
            self.directed_graph))
        if self._is_one_task(
                directed_graph_toposorted=directed_graph_toposorted):
            sfn_task = self.add_task(next(iter(directed_graph_toposorted[0])))
            self.chain_of_tasks.append(sfn_task)
        else:
            for element in directed_graph_toposorted:
                if len(element) == 1:
                    sfn_task = self.add_task(next(iter(element)))
                elif len(element) > 1:
                    sfn_task = self.add_parallel_tasks(element)
                else:
                    raise StepfunctionsWorkflowException(
                        "cannot have an index in the directed graph with 0 elements"
                    )
                self.chain_of_tasks.append(sfn_task)
        return self.chain_of_tasks
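The grouping that the docstring above describes comes directly from the `toposort` package: each set it yields contains tasks whose dependencies are already satisfied, so a singleton set becomes a single task and a larger set becomes a group of parallel tasks. Below is a minimal, self-contained sketch of the two cases; the task names, and the assumption that the Ellipsis node depends on the single task, are purely illustrative.

from collections import defaultdict

import toposort

# Single-task case: a graph of the shape [{task}, {Ellipsis}], which is what
# the _is_one_task() check above looks for (dependency direction assumed here).
single = defaultdict(set)
single[...].add("task_a")
print(list(toposort.toposort(single)))
# -> [{'task_a'}, {Ellipsis}]

# Fan-out case: job_1 feeds job_2 and job_3, which both feed job_4.
fan_out = defaultdict(set)
fan_out["job_2"].add("job_1")
fan_out["job_3"].add("job_1")
fan_out["job_4"].update({"job_2", "job_3"})
print(list(toposort.toposort(fan_out)))
# -> [{'job_1'}, {'job_2', 'job_3'}, {'job_4'}]
# Singleton sets map to add_task(); the two-element set maps to add_parallel_tasks().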
def test_training_step(pca_estimator_fixture, record_set_fixture, sfn_client,
                       sfn_role_arn):
    # Build workflow definition
    job_name = generate_job_name()
    training_step = TrainingStep('create_training_job_step',
                                 estimator=pca_estimator_fixture,
                                 job_name=job_name,
                                 data=record_set_fixture,
                                 mini_batch_size=200)
    workflow_graph = Chain([training_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-training-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TrainingJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
def test_model_step(trained_estimator, sfn_client, sagemaker_session,
                    sfn_role_arn):
    # Build workflow definition
    model_name = generate_job_name()
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=model_name)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ModelArn") is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        model_name = get_resource_name_from_arn(
            execution_output.get("ModelArn")).split("/")[1]
        delete_sagemaker_model(model_name, sagemaker_session)
Example #4
def workflow(client):
    execution_input = ExecutionInput()

    test_step_01 = Pass(state_id='StateOne',
                        parameters={
                            'ParamA': execution_input['Key02']['Key03'],
                            'ParamD': execution_input['Key01']['Key03'],
                        })

    test_step_02 = Pass(state_id='StateTwo',
                        parameters={
                            'ParamC': execution_input["Key05"],
                            "ParamB": "SampleValueB",
                            "ParamE":
                            test_step_01.output()["Response"]["Key04"]
                        })

    test_step_03 = Pass(state_id='StateThree',
                        parameters={
                            'ParamG': "SampleValueG",
                            "ParamF": execution_input["Key06"],
                            "ParamH": "SampleValueH",
                            "ParamI": test_step_02.output()
                        })

    workflow_definition = Chain([test_step_01, test_step_02, test_step_03])
    workflow = Workflow(name='TestWorkflow',
                        definition=workflow_definition,
                        role='testRoleArn',
                        execution_input=execution_input,
                        client=client)
    return workflow
Example #5
def test_step_input_order_validation():
    workflow_input = ExecutionInput()

    test_step_01 = Pass(state_id='StateOne',
                        parameters={
                            'ParamA': workflow_input['Key02']['Key03'],
                            'ParamD': workflow_input['Key01']['Key03'],
                        })

    test_step_02 = Pass(state_id='StateTwo',
                        parameters={
                            'ParamC': workflow_input["Key05"],
                            "ParamB": "SampleValueB",
                            "ParamE":
                            test_step_01.output()["Response"]["Key04"]
                        })

    test_step_03 = Pass(state_id='StateThree',
                        parameters={
                            'ParamG': "SampleValueG",
                            "ParamF": workflow_input["Key06"],
                            "ParamH": "SampleValueH"
                        })

    workflow_definition = Chain([test_step_01, test_step_03, test_step_02])

    with pytest.raises(ValueError):
        result = Graph(workflow_definition).to_dict()
Example #6
def test_choice_state_with_placeholders():

    first_state = Task(
        'FirstState',
        resource='arn:aws:lambda:us-east-1:1234567890:function:FirstState')
    retry = Chain([Pass('Retry'), Pass('Cleanup'), first_state])

    choice_state = Choice('Is Completed?')
    choice_state.add_choice(
        ChoiceRule.BooleanEquals(choice_state.output()["Completed"], True),
        Succeed('Complete'))
    choice_state.add_choice(
        ChoiceRule.BooleanEquals(choice_state.output()["Completed"], False),
        retry)

    first_state.next(choice_state)

    result = Graph(first_state).to_dict()

    expected_repr = {
        "StartAt": "FirstState",
        "States": {
            "FirstState": {
                "Resource":
                "arn:aws:lambda:us-east-1:1234567890:function:FirstState",
                "Type": "Task",
                "Next": "Is Completed?"
            },
            "Is Completed?": {
                "Type":
                "Choice",
                "Choices": [{
                    "Variable": "$['Completed']",
                    "BooleanEquals": True,
                    "Next": "Complete"
                }, {
                    "Variable": "$['Completed']",
                    "BooleanEquals": False,
                    "Next": "Retry"
                }]
            },
            "Complete": {
                "Type": "Succeed"
            },
            "Retry": {
                "Type": "Pass",
                "Next": "Cleanup"
            },
            "Cleanup": {
                "Type": "Pass",
                "Next": "FirstState"
            }
        }
    }

    assert result == expected_repr
Example #7
def get_existing_training_pipeline(workflow_arn):
    """
    Create a dummy implementation to get an existing training pipeline
    """
    training_pipeline = Workflow(
        name="training_pipeline_name",
        definition=Chain([]),
        role="workflow_execution_role",
    )

    return training_pipeline.attach(workflow_arn)
Example #8
def get_existing_monitor_pipeline(workflow_arn):
    """
    Create a dummy implementation to get an existing data pipeline
    """
    data_pipeline = Workflow(
        name="data_pipeline_name",
        definition=Chain([]),
        role="workflow_execution_role",
    )

    return data_pipeline.attach(workflow_arn)
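Both helpers above rely on `Workflow.attach` to bind a locally constructed (here, empty) definition to a state machine that already exists, so it can be executed without being re-created. A minimal usage sketch; the state machine ARN is illustrative.

existing_arn = ("arn:aws:states:us-east-1:123456789012:"
                "stateMachine:data_pipeline_name")
data_pipeline = get_existing_monitor_pipeline(existing_arn)
execution = data_pipeline.execute()
# As in the integration tests below, execution.get_output(wait=True) would
# block until the state machine run finishes.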
def test_chaining_choice_sets_default_field():
    s1_pass = Pass('Step - One')
    s2_choice = Choice('Step - Two')
    s3_pass = Pass('Step - Three')

    chain1 = Chain([s1_pass, s2_choice, s3_pass])
    assert chain1.steps == [s1_pass, s2_choice, s3_pass]
    assert s1_pass.next_step == s2_choice
    assert s2_choice.default == s3_pass
    assert s2_choice.next_step is None  # Choice steps do not have next_step
    assert s3_pass.next_step is None
def test_model_step_with_placeholders(trained_estimator, sfn_client,
                                      sagemaker_session, sfn_role_arn):
    # Build workflow definition
    execution_input = ExecutionInput(schema={
        'ModelName': str,
        'Mode': str,
        'Tags': list
    })

    parameters = {
        'PrimaryContainer': {
            'Mode': execution_input['Mode']
        },
        'Tags': execution_input['Tags']
    }

    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=execution_input['ModelName'],
                           parameters=parameters)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        inputs = {
            'ModelName': generate_job_name(),
            'Mode': 'SingleModel',
            'Tags': [{
                'Key': 'Environment',
                'Value': 'test'
            }]
        }

        # Execute workflow
        execution = workflow.execute(inputs=inputs)
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("ModelArn") is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        model_name = get_resource_name_from_arn(
            execution_output.get("ModelArn")).split("/")[1]
        delete_sagemaker_model(model_name, sagemaker_session)
def test_tuning_step(sfn_client, record_set_for_hyperparameter_tuning,
                     sagemaker_role_arn, sfn_role_arn):
    job_name = generate_job_name()

    kmeans = KMeans(role=sagemaker_role_arn,
                    instance_count=1,
                    instance_type=INSTANCE_TYPE,
                    k=10)

    hyperparameter_ranges = {
        "extra_center_factor": IntegerParameter(4, 10),
        "mini_batch_size": IntegerParameter(10, 100),
        "epochs": IntegerParameter(1, 2),
        "init_method": CategoricalParameter(["kmeans++", "random"]),
    }

    tuner = HyperparameterTuner(
        estimator=kmeans,
        objective_metric_name="test:msd",
        hyperparameter_ranges=hyperparameter_ranges,
        objective_type="Minimize",
        max_jobs=2,
        max_parallel_jobs=2,
    )

    # Build workflow definition
    tuning_step = TuningStep('Tuning',
                             tuner=tuner,
                             job_name=job_name,
                             data=record_set_for_hyperparameter_tuning)
    tuning_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([tuning_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-tuning-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get(
            "HyperParameterTuningJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
Example #12
def get_existing_inference_pipeline(workflow_arn):
    """
    Create a dummy implementation to get an existing inference pipeline

    TODO: This could be a good PR for the SDK.
    """
    inference_pipeline = Workflow(
        name="inference_pipeline_name",
        definition=Chain([]),
        role="workflow_execution_role",
    )

    return inference_pipeline.attach(workflow_arn)
def test_map_state_with_placeholders():
    workflow_input = ExecutionInput()
    step_result = StepResult()

    map_state = Map(state_id="MapState01",
                    result_selector={
                        "foo": step_result["foo"],
                        "bar": step_result["bar1"]["bar2"]
                    })
    iterator_state = Pass("TrainIterator",
                          parameters={
                              "ParamA": map_state.output()["X"]["Y"],
                              "ParamB":
                              workflow_input["Key01"]["Key02"]["Key03"]
                          })

    map_state.attach_iterator(iterator_state)
    workflow_definition = Chain([map_state])

    expected_repr = {
        "StartAt": "MapState01",
        "States": {
            "MapState01": {
                "Type": "Map",
                "ResultSelector": {
                    "foo.$": "$['foo']",
                    "bar.$": "$['bar1']['bar2']"
                },
                "End": True,
                "Iterator": {
                    "StartAt": "TrainIterator",
                    "States": {
                        "TrainIterator": {
                            "Parameters": {
                                "ParamA.$":
                                "$['X']['Y']",
                                "ParamB.$":
                                "$$.Execution.Input['Key01']['Key02']['Key03']"
                            },
                            "Type": "Pass",
                            "End": True
                        }
                    }
                }
            }
        }
    }

    result = Graph(workflow_definition).to_dict()
    assert result == expected_repr
Example #14
def test_wait_loop():
    first_state = Task(
        'FirstState',
        resource='arn:aws:lambda:us-east-1:1234567890:function:FirstState')
    retry = Chain([Pass('Retry'), Pass('Cleanup'), first_state])

    choice_state = Choice('Is Completed?')
    choice_state.add_choice(ChoiceRule.BooleanEquals('$.Completed', True),
                            Succeed('Complete'))
    choice_state.add_choice(ChoiceRule.BooleanEquals('$.Completed', False),
                            retry)
    first_state.next(choice_state)

    result = Graph(first_state).to_dict()
    assert result == {
        'StartAt': 'FirstState',
        'States': {
            'FirstState': {
                'Type': 'Task',
                'Resource':
                'arn:aws:lambda:us-east-1:1234567890:function:FirstState',
                'Next': 'Is Completed?'
            },
            'Is Completed?': {
                'Type':
                'Choice',
                'Choices': [{
                    'Variable': '$.Completed',
                    'BooleanEquals': True,
                    'Next': 'Complete'
                }, {
                    'Variable': '$.Completed',
                    'BooleanEquals': False,
                    'Next': 'Retry'
                }]
            },
            'Complete': {
                'Type': 'Succeed'
            },
            'Retry': {
                'Type': 'Pass',
                'Next': 'Cleanup',
            },
            'Cleanup': {
                'Type': 'Pass',
                'Next': 'FirstState'
            }
        }
    }
def test_chaining_choice_with_existing_default_overrides_value(caplog):
    s1_pass = Pass('Step - One')
    s2_choice = Choice('Step - Two')
    s3_pass = Pass('Step - Three')

    s2_choice.default_choice(s3_pass)

    # Chain s2_choice when default_choice is already set will trigger Warning message
    with caplog.at_level(logging.WARNING):
        Chain([s2_choice, s1_pass])
        expected_warning = f'Chaining Choice state: Overwriting {s2_choice.state_id}\'s current default_choice ({s3_pass.state_id}) with {s1_pass.state_id}'
        assert expected_warning in caplog.text
        assert 'WARNING' in caplog.text
    assert s2_choice.default == s1_pass
    assert s2_choice.next_step is None  # Choice steps do not have next_step
def test_transform_step(trained_estimator, sfn_client, sfn_role_arn):
    # Create transformer from previously created estimator
    job_name = generate_job_name()
    pca_transformer = trained_estimator.transformer(
        instance_count=INSTANCE_COUNT, instance_type=INSTANCE_TYPE)

    # Create a model step to save the model
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=job_name)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)

    # Upload data for transformation to S3
    data_path = os.path.join(DATA_DIR, "one_p_mnist")
    transform_input_path = os.path.join(data_path, "transform_input.csv")
    transform_input_key_prefix = "integ-test-data/one_p_mnist/transform"
    transform_input = pca_transformer.sagemaker_session.upload_data(
        path=transform_input_path, key_prefix=transform_input_key_prefix)

    # Build workflow definition
    transform_step = TransformStep('create_transform_job_step',
                                   pca_transformer,
                                   job_name=job_name,
                                   model_name=job_name,
                                   data=transform_input,
                                   content_type="text/csv")
    transform_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step, transform_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-transform-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        assert execution_output.get("TransformJobStatus") == "Completed"

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
Example #17
    def _integrate_notification_in_workflow(self,
                                            chain_of_tasks: Chain) -> Chain:
        """If a notification is defined we configure an SNS with email
        subscription to alert the user if the stepfunctions workflow failed or
        succeeded.

        :param chain_of_tasks: the workflow definition that contains all the steps we want to execute.
        :return: if notification is set, we adapt the workflow to include an SnsPublishStep on failure or on success.
        If notification is not set, we return the workflow as we received it.
        """
        if self.notification:
            logger.debug(
                "A notification is configured, "
                "implementing a notification on Error or when the stepfunctions workflow succeeds."
            )
            failure_notification = SnsPublishStep(
                "FailureNotification",
                parameters={
                    "TopicArn":
                    self.notification.get_topic_arn(),
                    "Message":
                    f"Stepfunctions workflow {self.unique_name} Failed.",
                },
            )
            pass_notification = SnsPublishStep(
                "SuccessNotification",
                parameters={
                    "TopicArn":
                    self.notification.get_topic_arn(),
                    "Message":
                    f"Stepfunctions workflow {self.unique_name} Succeeded.",
                },
            )

            catch_error = Catch(error_equals=["States.ALL"],
                                next_step=failure_notification)
            workflow_with_notification = Parallel(state_id="notification")
            workflow_with_notification.add_branch(chain_of_tasks)
            workflow_with_notification.add_catch(catch_error)
            workflow_with_notification.next(pass_notification)
            return Chain([workflow_with_notification])
        logger.debug(
            "No notification is configured, returning the workflow definition."
        )
        return chain_of_tasks
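The wrapping pattern above can be reproduced standalone: the task chain becomes the single branch of a Parallel state, a Catch on "States.ALL" routes failures to one SnsPublishStep, and .next() appends the success notification. A minimal sketch using the same SDK classes; the import paths, topic ARN, and state names are assumptions for illustration.

from stepfunctions.steps import Chain, Parallel, Pass
from stepfunctions.steps.service import SnsPublishStep
from stepfunctions.steps.states import Catch

topic_arn = "arn:aws:sns:eu-west-1:123456789012:workflow-alerts"

failure_notification = SnsPublishStep(
    "FailureNotification",
    parameters={"TopicArn": topic_arn, "Message": "Workflow Failed."},
)
success_notification = SnsPublishStep(
    "SuccessNotification",
    parameters={"TopicArn": topic_arn, "Message": "Workflow Succeeded."},
)

chain_of_tasks = Chain([Pass("TaskOne"), Pass("TaskTwo")])

workflow_with_notification = Parallel(state_id="notification")
workflow_with_notification.add_branch(chain_of_tasks)
workflow_with_notification.add_catch(
    Catch(error_equals=["States.ALL"], next_step=failure_notification))
workflow_with_notification.next(success_notification)

workflow_definition = Chain([workflow_with_notification])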
Example #18
def create_sfn_workflow(params, steps):
    sfn_workflow_name = params['sfn-workflow-name']
    workflow_execution_role = params['sfn-role-arn']

    workflow_graph = Chain(steps)

    branching_workflow = Workflow(
        name=sfn_workflow_name,
        definition=workflow_graph,
        role=workflow_execution_role,
    )

    branching_workflow.create()
    branching_workflow.update(workflow_graph)

    time.sleep(5)

    return branching_workflow
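A minimal usage sketch for the helper above; the workflow name, role ARN, and Pass steps are illustrative placeholders for real task states, and the import path is assumed.

from stepfunctions.steps import Pass

params = {
    'sfn-workflow-name': 'example-branching-workflow',
    'sfn-role-arn': ('arn:aws:iam::123456789012:'
                     'role/StepFunctionsWorkflowExecutionRole'),
}
steps = [Pass('Prepare'), Pass('Train')]

branching_workflow = create_sfn_workflow(params, steps)
execution = branching_workflow.execute()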
def test_create_endpoint_step(trained_estimator, record_set_fixture,
                              sfn_client, sagemaker_session, sfn_role_arn):
    # Setup: Create model and endpoint config for trained estimator in SageMaker
    model = trained_estimator.create_model()
    model._create_sagemaker_model(instance_type=INSTANCE_TYPE)
    endpoint_config = model.sagemaker_session.create_endpoint_config(
        name=model.name,
        model_name=model.name,
        initial_instance_count=INSTANCE_COUNT,
        instance_type=INSTANCE_TYPE)
    # End of Setup

    # Build workflow definition
    endpoint_name = unique_name_from_base("integ-test-endpoint")
    endpoint_step = EndpointStep('create_endpoint_step',
                                 endpoint_name=endpoint_name,
                                 endpoint_config_name=model.name)
    endpoint_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([endpoint_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-create-endpoint-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Execute workflow
        execution = workflow.execute()
        execution_output = execution.get_output(wait=True)

        # Check workflow output
        endpoint_arn = execution_output.get("EndpointArn")
        assert endpoint_arn is not None
        assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200

        # Cleanup
        state_machine_delete_wait(sfn_client, workflow.state_machine_arn)
        delete_sagemaker_endpoint(endpoint_name, sagemaker_session)
        delete_sagemaker_endpoint_config(model.name, sagemaker_session)
        delete_sagemaker_model(model.name, sagemaker_session)
Example #20
def test_map_state_with_placeholders():
    workflow_input = ExecutionInput()

    map_state = Map('MapState01')
    iterator_state = Pass('TrainIterator',
                          parameters={
                              'ParamA': map_state.output()['X']["Y"],
                              'ParamB':
                              workflow_input["Key01"]["Key02"]["Key03"]
                          })

    map_state.attach_iterator(iterator_state)
    workflow_definition = Chain([map_state])

    expected_repr = {
        "StartAt": "MapState01",
        "States": {
            "MapState01": {
                "Type": "Map",
                "End": True,
                "Iterator": {
                    "StartAt": "TrainIterator",
                    "States": {
                        "TrainIterator": {
                            "Parameters": {
                                "ParamA.$":
                                "$['X']['Y']",
                                "ParamB.$":
                                "$$.Execution.Input['Key01']['Key02']['Key03']"
                            },
                            "Type": "Pass",
                            "End": True
                        }
                    }
                }
            }
        }
    }

    result = Graph(workflow_definition).to_dict()
    assert result == expected_repr
    def build_workflow_definition(self):
        """
        Build the workflow definition for the training pipeline with all the states involved.

        Returns:
            :class:`~stepfunctions.steps.states.Chain`: Workflow definition as a chain of states involved in the training pipeline.
        """
        default_name = self.pipeline_name

        train_instance_type = self.estimator.train_instance_type
        train_instance_count = self.estimator.train_instance_count

        training_step = TrainingStep(
            StepId.Train.value,
            estimator=self.estimator,
            job_name=default_name + '/estimator-source',
            data=self.inputs,
        )

        model = self.estimator.create_model()
        model_step = ModelStep(StepId.CreateModel.value,
                               instance_type=train_instance_type,
                               model=model,
                               model_name=default_name)

        endpoint_config_step = EndpointConfigStep(
            StepId.ConfigureEndpoint.value,
            endpoint_config_name=default_name,
            model_name=default_name,
            initial_instance_count=train_instance_count,
            instance_type=train_instance_type)
        deploy_step = EndpointStep(
            StepId.Deploy.value,
            endpoint_name=default_name,
            endpoint_config_name=default_name,
        )

        return Chain(
            [training_step, model_step, endpoint_config_step, deploy_step])
def test_nested_chain_is_now_allowed():
    chain = Chain([Chain([Pass('S1')])])
Example #23
def define_training_pipeline(
    sm_role,
    workflow_execution_role,
    training_pipeline_name,
    return_yaml=True,
    dump_yaml_file="templates/sagemaker_training_pipeline.yaml",
    kms_key_id=None,
):
    """
    Return YAML definition of the training pipeline, which consists of multiple
    Amazon StepFunction steps

    sm_role:                    ARN of the SageMaker execution role
    workflow_execution_role:    ARN of the StepFunction execution role
    return_yaml:                Return YAML representation or not, if False,
                                it returns an instance of
                                    `stepfunctions.workflow.WorkflowObject`
    dump_yaml_file:             If not None, a YAML file will be generated at
                                    this file location

    """

    # Pass required parameters dynamically for each execution using placeholders.
    execution_input = ExecutionInput(
        schema={
            "InputDataURL": str,
            "PreprocessingJobName": str,
            "PreprocessingCodeURL": str,
            "TrainingJobName": str,
            # Prevent sagemaker config hardcode sagemaker_submit_directory in
            # workflow definition
            "SMSubmitDirURL": str,
            # Prevent sagemaker config hardcode sagemaker_region in workflow definition
            "SMRegion": str,
            "EvaluationProcessingJobName": str,
            "EvaluationCodeURL": str,
            "EvaluationResultURL": str,
            "PreprocessedTrainDataURL": str,
            "PreprocessedTestDataURL": str,
            "PreprocessedModelURL": str,
            "SMOutputDataURL": str,
            "SMDebugOutputURL": str,
        })
    """
    Data pre-processing and feature engineering
    """
    sklearn_processor = SKLearnProcessor(
        framework_version="0.20.0",
        role=sm_role,
        instance_type="ml.m5.xlarge",
        instance_count=1,
        max_runtime_in_seconds=1200,
    )

    # Create ProcessingInputs and ProcessingOutputs objects for Inputs and
    # Outputs respectively for the SageMaker Processing Job
    inputs = [
        ProcessingInput(
            source=execution_input["InputDataURL"],
            destination="/opt/ml/processing/input",
            input_name="input-1",
        ),
        ProcessingInput(
            source=execution_input["PreprocessingCodeURL"],
            destination="/opt/ml/processing/input/code",
            input_name="code",
        ),
    ]

    outputs = [
        ProcessingOutput(
            source="/opt/ml/processing/train",
            destination=execution_input["PreprocessedTrainDataURL"],
            output_name="train_data",
        ),
        ProcessingOutput(
            source="/opt/ml/processing/test",
            destination=execution_input["PreprocessedTestDataURL"],
            output_name="test_data",
        ),
        ProcessingOutput(
            source="/opt/ml/processing/model",
            destination=execution_input["PreprocessedModelURL"],
            output_name="proc_model",
        ),
    ]

    processing_step = ProcessingStep(
        "SageMaker pre-processing step",
        processor=sklearn_processor,
        job_name=execution_input["PreprocessingJobName"],
        inputs=inputs,
        outputs=outputs,
        container_arguments=[
            "--train-test-split-ratio", "0.2", "--mode", "train"
        ],
        container_entrypoint=[
            "python3",
            "/opt/ml/processing/input/code/preprocessing.py",
        ],
        kms_key_id=kms_key_id,
    )
    """
    Training using the pre-processed data
    """
    sklearn = SKLearn(
        entry_point="../../src/mlmax/train.py",
        train_instance_type="ml.m5.xlarge",
        role=sm_role,
        py_version="py3",
        framework_version="0.20.0",
        output_kms_key=kms_key_id,
    )

    training_step = MLMaxTrainingStep(
        "SageMaker Training Step",
        estimator=sklearn,
        job_name=execution_input["TrainingJobName"],
        train_data=execution_input["PreprocessedTrainDataURL"],
        test_data=execution_input["PreprocessedTestDataURL"],
        sm_submit_url=execution_input["SMSubmitDirURL"],
        sm_region=execution_input["SMRegion"],
        sm_output_data=execution_input["SMOutputDataURL"],
        sm_debug_output_data=execution_input["SMDebugOutputURL"],
        wait_for_completion=True,
    )
    """
    Model evaluation
    """
    # Create input and output objects for Model Evaluation ProcessingStep.
    inputs_evaluation = [
        ProcessingInput(
            source=execution_input["PreprocessedTestDataURL"],
            destination="/opt/ml/processing/test",
            input_name="input-1",
        ),
        ProcessingInput(
            source=training_step.get_expected_model().model_data,
            destination="/opt/ml/processing/model",
            input_name="input-2",
        ),
        ProcessingInput(
            source=execution_input["EvaluationCodeURL"],
            destination="/opt/ml/processing/input/code",
            input_name="code",
        ),
    ]

    outputs_evaluation = [
        ProcessingOutput(
            source="/opt/ml/processing/evaluation",
            destination=execution_input["EvaluationResultURL"],
            output_name="evaluation",
        ),
    ]

    model_evaluation_processor = SKLearnProcessor(
        framework_version="0.20.0",
        role=sm_role,
        instance_type="ml.m5.xlarge",
        instance_count=1,
        max_runtime_in_seconds=1200,
    )

    processing_evaluation_step = ProcessingStep(
        "SageMaker Processing Model Evaluation step",
        processor=model_evaluation_processor,
        job_name=execution_input["EvaluationProcessingJobName"],
        inputs=inputs_evaluation,
        outputs=outputs_evaluation,
        container_entrypoint=[
            "python3", "/opt/ml/processing/input/code/evaluation.py"
        ],
    )

    # Create Fail state to mark the workflow failed in case any of the steps fail.
    failed_state_sagemaker_processing_failure = stepfunctions.steps.states.Fail(
        "ML Workflow failed", cause="SageMakerProcessingJobFailed")

    # Add the Error handling in the workflow
    catch_state_processing = stepfunctions.steps.states.Catch(
        error_equals=["States.TaskFailed"],
        next_step=failed_state_sagemaker_processing_failure,
    )
    processing_step.add_catch(catch_state_processing)
    processing_evaluation_step.add_catch(catch_state_processing)
    training_step.add_catch(catch_state_processing)

    # Create the Workflow
    workflow_graph = Chain(
        [processing_step, training_step, processing_evaluation_step])
    training_pipeline = Workflow(
        name=training_pipeline_name,
        definition=workflow_graph,
        role=workflow_execution_role,
    )
    return training_pipeline
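A usage sketch for the pipeline factory above; the role ARNs, pipeline name, and S3 URLs are illustrative, and the execution inputs simply echo the ExecutionInput schema declared in define_training_pipeline.

training_pipeline = define_training_pipeline(
    sm_role="arn:aws:iam::123456789012:role/SageMakerExecutionRole",
    workflow_execution_role=("arn:aws:iam::123456789012:"
                             "role/StepFunctionsWorkflowExecutionRole"),
    training_pipeline_name="example-training-pipeline",
)
training_pipeline.create()
training_pipeline.execute(inputs={
    "InputDataURL": "s3://example-bucket/input/data.csv",
    "PreprocessingJobName": "preprocess-job-001",
    "PreprocessingCodeURL": "s3://example-bucket/code/preprocessing.py",
    "TrainingJobName": "training-job-001",
    "SMSubmitDirURL": "s3://example-bucket/code/sourcedir.tar.gz",
    "SMRegion": "us-east-1",
    "EvaluationProcessingJobName": "evaluation-job-001",
    "EvaluationCodeURL": "s3://example-bucket/code/evaluation.py",
    "EvaluationResultURL": "s3://example-bucket/evaluation/",
    "PreprocessedTrainDataURL": "s3://example-bucket/preprocessed/train/",
    "PreprocessedTestDataURL": "s3://example-bucket/preprocessed/test/",
    "PreprocessedModelURL": "s3://example-bucket/preprocessed/model/",
    "SMOutputDataURL": "s3://example-bucket/output/",
    "SMDebugOutputURL": "s3://example-bucket/debug/",
})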
    def build_workflow_definition(self):
        """
        Build the workflow definition for the inference pipeline with all the states involved.

        Returns:
            :class:`~stepfunctions.steps.states.Chain`: Workflow definition as a chain of states involved in the inference pipeline.
        """
        default_name = self.pipeline_name

        train_instance_type = self.preprocessor.train_instance_type
        train_instance_count = self.preprocessor.train_instance_count

        # Preprocessor for feature transformation
        preprocessor_train_step = TrainingStep(
            StepId.TrainPreprocessor.value,
            estimator=self.preprocessor,
            job_name=default_name + '/preprocessor-source',
            data=self.inputs,
        )
        preprocessor_model = self.preprocessor.create_model()
        preprocessor_model_step = ModelStep(
            StepId.CreatePreprocessorModel.value,
            instance_type=train_instance_type,
            model=preprocessor_model,
            model_name=default_name)
        preprocessor_transform_step = TransformStep(
            StepId.TransformInput.value,
            transformer=self.preprocessor.transformer(
                instance_count=train_instance_count,
                instance_type=train_instance_type,
                max_payload=20),
            job_name=default_name,
            model_name=default_name,
            data=self.inputs['train'],
            compression_type=self.compression_type,
            content_type=self.content_type)

        # Training
        train_instance_type = self.estimator.train_instance_type
        train_instance_count = self.estimator.train_instance_count

        training_step = TrainingStep(
            StepId.Train.value,
            estimator=self.estimator,
            job_name=default_name + '/estimator-source',
            data=self.inputs,
        )

        pipeline_model = PipelineModel(name='PipelineModel',
                                       role=self.estimator.role,
                                       models=[
                                           self.preprocessor.create_model(),
                                           self.estimator.create_model()
                                       ])
        pipeline_model_step = ModelStep(StepId.CreatePipelineModel.value,
                                        instance_type=train_instance_type,
                                        model=preprocessor_model,
                                        model_name=default_name)
        pipeline_model_step.parameters = self.pipeline_model_config(
            train_instance_type, pipeline_model)

        deployable_model = Model(model_data='', image='')

        # Deployment
        endpoint_config_step = EndpointConfigStep(
            StepId.ConfigureEndpoint.value,
            endpoint_config_name=default_name,
            model_name=default_name,
            initial_instance_count=train_instance_count,
            instance_type=train_instance_type)

        deploy_step = EndpointStep(
            StepId.Deploy.value,
            endpoint_name=default_name,
            endpoint_config_name=default_name,
        )

        return Chain([
            preprocessor_train_step, preprocessor_model_step,
            preprocessor_transform_step, training_step, pipeline_model_step,
            endpoint_config_step, deploy_step
        ])
Example #25
class StepfunctionsWorkflow(DataJobBase):
    """Class that defines the methods to create and execute an orchestration
    using the step functions sdk.

    example:

        with StepfunctionsWorkflow("techskills-parser") as tech_skills_parser_orchestration:

            some-glue-job-1 >> [some-glue-job-2,some-glue-job-3] >> some-glue-job-4

        tech_skills_parser_orchestration.execute()
    """
    def __init__(
        self,
        datajob_stack: core.Construct,
        name: str,
        notification: Union[str, list] = None,
        role: iam.Role = None,
        region: str = None,
        **kwargs,
    ):
        super().__init__(datajob_stack, name, **kwargs)
        self.workflow = None
        self.chain_of_tasks = None
        self.role = self.get_role(
            role=role,
            datajob_stack=datajob_stack,
            unique_name=self.unique_name,
            service_principal="states.amazonaws.com",
        )
        self.region = (region if region is not None else
                       os.environ.get("AWS_DEFAULT_REGION"))
        self.notification = self._setup_notification(notification)
        self.kwargs = kwargs
        # init directed graph dict where values are a set.
        # we do it like this so that we can use toposort.
        self.directed_graph = defaultdict(set)

    def add_task(self, some_task: DataJobBase) -> object:
        """get the stepfunctions  task,  sfn_task, we would like to
        orchestrate."""
        return some_task.sfn_task

    def add_parallel_tasks(self,
                           parallel_tasks: Iterator[DataJobBase]) -> Parallel:
        """add tasks in parallel (wrapped in a list) to the workflow we would
        like to orchestrate."""
        parallel_pipelines = Parallel(state_id=uuid.uuid4().hex)
        for a_task in parallel_tasks:
            logger.debug(f"adding parallel task {a_task}")
            sfn_task = self.add_task(a_task)
            parallel_pipelines.add_branch(sfn_task)
        return parallel_pipelines

    def _is_one_task(self, directed_graph_toposorted):
        """If we have length of 2 and the second is an Ellipsis object we have
        scheduled 1 task.

        example:
            some_task >> ...

        :param directed_graph_toposorted: the toposorted graph containing all the sorted tasks
        :return: boolean
        """
        return len(directed_graph_toposorted) == 2 and isinstance(
            list(directed_graph_toposorted[1])[0], type(Ellipsis))

    def _construct_toposorted_chain_of_tasks(self) -> Chain:
        """Take the directed graph and toposort so that we can efficiently
        organize our workflow, i.e. parallelize where possible.

        If we have 2 elements and one of them is an Ellipsis object, we need to orchestrate just 1 job.
        Otherwise we loop over the toposorted dag and assign a single stepfunctions task
        or multiple tasks in parallel.

        Returns: toposorted chain of tasks
        """
        self.chain_of_tasks = Chain()
        directed_graph_toposorted = list(toposort.toposort(
            self.directed_graph))
        if self._is_one_task(
                directed_graph_toposorted=directed_graph_toposorted):
            sfn_task = self.add_task(next(iter(directed_graph_toposorted[0])))
            self.chain_of_tasks.append(sfn_task)
        else:
            for element in directed_graph_toposorted:
                if len(element) == 1:
                    sfn_task = self.add_task(next(iter(element)))
                elif len(element) > 1:
                    sfn_task = self.add_parallel_tasks(element)
                else:
                    raise StepfunctionsWorkflowException(
                        "cannot have an index in the directed graph with 0 elements"
                    )
                self.chain_of_tasks.append(sfn_task)
        return self.chain_of_tasks

    def build_workflow(self):
        """create a step functions workflow from the chain_of_tasks."""
        self.chain_of_tasks = self._construct_toposorted_chain_of_tasks()
        logger.debug("creating a chain from all the different steps.")
        self.chain_of_tasks = self._integrate_notification_in_workflow(
            chain_of_tasks=self.chain_of_tasks)
        logger.debug(f"creating a workflow with name {self.unique_name}")
        sfn_client = boto3.client("stepfunctions")
        self.workflow = Workflow(
            name=self.unique_name,
            definition=self.chain_of_tasks,
            role=self.role.role_arn,
            client=sfn_client,
            **self.kwargs,
        )

    def create(self):
        """create sfn stack."""
        import json

        cfn_template = json.dumps(self.workflow.definition.to_dict())
        CfnStateMachine(
            scope=self.datajob_stack,
            id=self.unique_name,
            state_machine_name=self.unique_name,
            role_arn=self.role.role_arn,
            definition_string=cfn_template,
            **self.kwargs,
        )

    def _setup_notification(
            self, notification: Union[str, list]) -> Union[SnsTopic, None]:
        """Create a SnsTopic if the notification parameter is defined.

        :param notification: email address as string or list of email addresses to be subscribed.
        :return:
        """
        if notification is not None:
            name = f"{self.name}-notification"
            return SnsTopic(self.datajob_stack, name, notification)

    def _integrate_notification_in_workflow(self,
                                            chain_of_tasks: Chain) -> Chain:
        """If a notification is defined we configure an SNS with email
        subscription to alert the user if the stepfunctions workflow failed or
        succeeded.

        :param chain_of_tasks: the workflow definition that contains all the steps we want to execute.
        :return: if notification is set, we adapt the workflow to include an SnsPublishStep on failure or on success.
        If notification is not set, we return the workflow as we received it.
        """
        if self.notification:
            logger.debug(
                "A notification is configured, "
                "implementing a notification on Error or when the stepfunctions workflow succeeds."
            )
            failure_notification = SnsPublishStep(
                "FailureNotification",
                parameters={
                    "TopicArn":
                    self.notification.get_topic_arn(),
                    "Message":
                    f"Stepfunctions workflow {self.unique_name} Failed.",
                },
            )
            pass_notification = SnsPublishStep(
                "SuccessNotification",
                parameters={
                    "TopicArn":
                    self.notification.get_topic_arn(),
                    "Message":
                    f"Stepfunctions workflow {self.unique_name} Succeeded.",
                },
            )

            catch_error = Catch(error_equals=["States.ALL"],
                                next_step=failure_notification)
            workflow_with_notification = Parallel(state_id="notification")
            workflow_with_notification.add_branch(chain_of_tasks)
            workflow_with_notification.add_catch(catch_error)
            workflow_with_notification.next(pass_notification)
            return Chain([workflow_with_notification])
        logger.debug(
            "No notification is configured, returning the workflow definition."
        )
        return chain_of_tasks

    def __enter__(self):
        """first steps we have to do when entering the context manager."""
        logger.info(f"creating step functions workflow for {self.unique_name}")
        _set_workflow(self)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        """steps we have to do when exiting the context manager."""
        self.build_workflow()
        _set_workflow(None)
        logger.info(f"step functions workflow {self.unique_name} created")
Example #26
def define_monitor_pipeline(
    account,
    region,
    sm_role,
    workflow_execution_role,
    data_pipeline_name,
    return_yaml=True,
    dump_yaml_file="templates/sagemaker_data_pipeline.yaml",
):
    """
    Return YAML definition of the data monitoring pipeline, which consists of
    multiple Amazon StepFunction steps

    sm_role:                    ARN of the SageMaker execution role
    workflow_execution_role:    ARN of the StepFunction execution role
    return_yaml:                Return YAML representation or not, if False,
                                it returns an instance of
                                    `stepfunctions.workflow.WorkflowObject`
    dump_yaml_file:             If not None, a YAML file will be generated at
                                    this file location

    """

    # Pass required parameters dynamically for each execution using placeholders.
    execution_input = ExecutionInput(
        schema={
            "PreprocessingJobName": str,
            "PreprocessingInferJobName": str,
            "PreprocessingCodeURL": str,
            "MonitorTrainOutputURL": str,
            "MonitorInferOutputURL": str,
            "InputDataURL": str,
            "InferDataURL": str,
        })
    """
    Custom container for monitoring
    """
    image = "mlmax-processing-monitor"
    img_uri = f"{account}.dkr.ecr.{region}.amazonaws.com/{image}:latest"
    processor = ScriptProcessor(
        image_uri=img_uri,
        role=sm_role,
        instance_count=16,
        instance_type="ml.m5.2xlarge",
        command=["/opt/program/submit"],
        max_runtime_in_seconds=3600,
        env={"mode": "python"},
    )

    #############################
    # Baseline
    #############################
    # Create ProcessingInputs and ProcessingOutputs objects for Inputs and
    # Outputs respectively for the SageMaker Processing Job
    inputs = [
        ProcessingInput(
            source=execution_input["InputDataURL"],
            destination="/opt/ml/processing/train_input",
            input_name="train-input-data",
        ),
        ProcessingInput(
            source=execution_input["PreprocessingCodeURL"],
            destination="/opt/ml/processing/input/code",
            input_name="code",
        ),
    ]

    outputs = [
        ProcessingOutput(
            source="/opt/ml/processing/profiling/inference",
            destination=execution_input["MonitorTrainOutputURL"],
            output_name="baseline-data",
        )
    ]

    processing_step = ProcessingStep(
        "SageMaker pre-processing Baseline",
        processor=processor,
        job_name=execution_input["PreprocessingJobName"],
        inputs=inputs,
        outputs=outputs,
        container_arguments=[
            "--train-test-split-ratio", "0.2", "--mode", "train"
        ],
        container_entrypoint=[
            "python3",
            "/opt/ml/processing/input/code/monitoring.py",
        ],
    )

    #############################
    # Inference
    #############################
    inputs = [
        ProcessingInput(
            source=execution_input["InferDataURL"],
            destination="/opt/ml/processing/infer_input",
            input_name="infer-input-data",
        ),
        ProcessingInput(
            source=execution_input["MonitorTrainOutputURL"],
            destination="/opt/ml/processing/profiling",
            input_name="baseline-data",
        ),
        ProcessingInput(
            source=execution_input["PreprocessingCodeURL"],
            destination="/opt/ml/processing/input/code",
            input_name="code",
        ),
    ]

    outputs = [
        ProcessingOutput(
            source="/opt/ml/processing/profiling/inference",
            destination=execution_input["MonitorInferOutputURL"],
            output_name="monitor-output",
        )
    ]

    processing_step_inference = ProcessingStep(
        "SageMaker pre-processing Inference",
        processor=processor,
        job_name=execution_input["PreprocessingInferJobName"],
        inputs=inputs,
        outputs=outputs,
        container_arguments=["--mode", "infer"],
        container_entrypoint=[
            "python3",
            "/opt/ml/processing/input/code/monitoring.py",
        ],
    )

    # Create Fail state to mark the workflow failed in case any of the steps fail.
    failed_state_sagemaker_processing_failure = stepfunctions.steps.states.Fail(
        "ML Workflow failed", cause="SageMakerProcessingJobFailed")

    # Add the Error handling in the workflow
    catch_state_processing = stepfunctions.steps.states.Catch(
        error_equals=["States.TaskFailed"],
        next_step=failed_state_sagemaker_processing_failure,
    )
    processing_step.add_catch(catch_state_processing)
    processing_step_inference.add_catch(catch_state_processing)

    # Create the Workflow
    workflow_graph = Chain([processing_step, processing_step_inference])
    data_pipeline = Workflow(
        name=data_pipeline_name,
        definition=workflow_graph,
        role=workflow_execution_role,
    )
    return data_pipeline
Example #27
def test_workflow_with_placeholders():
    workflow_input = ExecutionInput()

    test_step_01 = Pass(state_id='StateOne',
                        parameters={
                            'ParamA': workflow_input['Key02']['Key03'],
                            'ParamD': workflow_input['Key01']['Key03'],
                        })

    test_step_02 = Pass(state_id='StateTwo',
                        parameters={
                            'ParamC': workflow_input["Key05"],
                            "ParamB": "SampleValueB",
                            "ParamE":
                            test_step_01.output()["Response"]["Key04"]
                        })

    test_step_03 = Pass(state_id='StateThree',
                        parameters={
                            'ParamG': "SampleValueG",
                            "ParamF": workflow_input["Key06"],
                            "ParamH": "SampleValueH"
                        })

    workflow_definition = Chain([test_step_01, test_step_02, test_step_03])

    result = Graph(workflow_definition).to_dict()

    expected_workflow_repr = {
        "StartAt": "StateOne",
        "States": {
            "StateOne": {
                "Type": "Pass",
                "Parameters": {
                    "ParamA.$": "$$.Execution.Input['Key02']['Key03']",
                    "ParamD.$": "$$.Execution.Input['Key01']['Key03']"
                },
                "Next": "StateTwo"
            },
            "StateTwo": {
                "Type": "Pass",
                "Parameters": {
                    "ParamC.$": "$$.Execution.Input['Key05']",
                    "ParamB": "SampleValueB",
                    "ParamE.$": "$['Response']['Key04']"
                },
                "Next": "StateThree"
            },
            "StateThree": {
                "Type": "Pass",
                "Parameters": {
                    "ParamG": "SampleValueG",
                    "ParamF.$": "$$.Execution.Input['Key06']",
                    "ParamH": "SampleValueH"
                },
                "End": True
            }
        }
    }

    assert result == expected_workflow_repr
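The expected representation above illustrates the two placeholder flavours: ExecutionInput keys render as "$$.Execution.Input[...]" context paths, while step output placeholders render as plain "$[...]" paths. A minimal sketch using the same APIs; the state and key names are illustrative and the import paths are assumed.

from stepfunctions.inputs import ExecutionInput
from stepfunctions.steps import Chain, Pass
from stepfunctions.steps.states import Graph

render_input = ExecutionInput()
render_step = Pass('RenderDemo',
                   parameters={'Bucket': render_input['BucketName']})
rendered = Graph(Chain([render_step])).to_dict()

# The plain 'Bucket' key becomes 'Bucket.$' pointing at the execution input.
assert rendered['States']['RenderDemo']['Parameters'] == {
    'Bucket.$': "$$.Execution.Input['BucketName']"
}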
def test_append_states_after_terminal_state_will_fail():
    with pytest.raises(ValueError):
        chain = Chain()
        chain.append(Pass('Pass'))
        chain.append(Fail('Fail'))
        chain.append(Pass('Pass2'))

    with pytest.raises(ValueError):
        chain = Chain()
        chain.append(Pass('Pass'))
        chain.append(Succeed('Succeed'))
        chain.append(Pass('Pass2'))

    with pytest.raises(ValueError):
        chain = Chain()
        chain.append(Pass('Pass'))
        chain.append(Choice('Choice'))
        chain.append(Pass('Pass2'))
def test_chaining_steps():
    s1 = Pass('Step - One')
    s2 = Pass('Step - Two')
    s3 = Pass('Step - Three')

    Chain([s1, s2])
    assert s1.next_step == s2
    assert s2.next_step is None

    chain1 = Chain([s2, s3])
    assert s2.next_step == s3

    chain2 = Chain([s1, s3])
    assert s1.next_step == s3
    assert s2.next_step == s1.next_step
    with pytest.raises(DuplicateStatesInChain):
        chain2.append(s3)

    with pytest.raises(DuplicateStatesInChain):
        chain3 = Chain([chain1, chain2])

    s1.next(s2)
    chain3 = Chain([s3, s1])
    assert chain3.steps == [s3, s1]
    assert s3.next_step == s1
    assert s1.next_step == s2
    assert s2.next_step == s3

    Chain([Chain([s3]), Chain([s1])])

    with pytest.raises(DuplicateStatesInChain):
        Chain([Chain([s1, s2, s1]), s3])
        Chain([s1, s2, s1, s3])
    Chain([Chain([s1, s2]), s3])
    assert s1.next_step == s2
    assert s2.next_step == s3
Example #30
def define_inference_pipeline(
    sm_role,
    workflow_execution_role,
    inference_pipeline_name,
    return_yaml=True,
    dump_yaml_file="templates/sagemaker_inference_pipeline.yaml",
    kms_key_id=None,
):
    """
    Return YAML definition of the inference pipeline, which consists of
    multiple Amazon StepFunction steps

    sm_role:                    ARN of the SageMaker execution role
    workflow_execution_role:    ARN of the StepFunction execution role
    return_yaml:                Return YAML representation or not, if False,
                     it returns an instance of `stepfunctions.workflow.WorkflowObject`
    dump_yaml_file:  If not None, a YAML file will be generated at this file location

    """

    # Pass required parameters dynamically for each execution using placeholders.
    execution_input = ExecutionInput(
        schema={
            "InputDataURL": str,
            "PreprocessingJobName": str,
            "InferenceJobName": str,
            "ProcModelS3": str,
            "PreprocessingCodeURL": str,
            "InferenceCodeURL": str,
            "ModelS3": str,
            "PreprocessedTrainDataURL": str,
            "PreprocessedTestDataURL": str,
            "OutputPathURL": str,
        })
    """
    Create Preprocessing Model from model artifact.
    """
    # sagemaker_session = sagemaker.Session()

    sklearn_processor = SKLearnProcessor(
        framework_version="0.20.0",
        role=sm_role,
        instance_type="ml.m5.xlarge",
        instance_count=1,
        max_runtime_in_seconds=1200,
    )
    # Create ProcessingInputs and ProcessingOutputs objects for Inputs and
    # Outputs respectively for the SageMaker Processing Job
    inputs = [
        ProcessingInput(
            source=execution_input["InputDataURL"],
            destination="/opt/ml/processing/input",
            input_name="input-1",
        ),
        ProcessingInput(
            source=execution_input["PreprocessingCodeURL"],
            destination="/opt/ml/processing/input/code",
            input_name="code",
        ),
        ProcessingInput(
            source=execution_input["ProcModelS3"],
            destination="/opt/ml/processing/model",
            input_name="proc_model",
        ),
    ]

    outputs = [
        ProcessingOutput(
            source="/opt/ml/processing/test",
            destination=execution_input["PreprocessedTestDataURL"],
            output_name="test_data",
        ),
    ]

    processing_step = ProcessingStep(
        "SageMaker pre-processing step",
        processor=sklearn_processor,
        job_name=execution_input["PreprocessingJobName"],
        inputs=inputs,
        outputs=outputs,
        container_arguments=["--mode", "infer"],
        container_entrypoint=[
            "python3",
            "/opt/ml/processing/input/code/preprocessing.py",
        ],
        kms_key_id=kms_key_id,
    )
    """
    Create inference with sklearn processing step.

    Inputs are the preprocessed data S3 URL, the inference code S3 URL, and
    the model S3 URL. Output is the inferred data.
    """
    sklearn_processor2 = SKLearnProcessor(
        framework_version="0.20.0",
        role=sm_role,
        instance_type="ml.m5.xlarge",
        instance_count=1,
        max_runtime_in_seconds=1200,
    )
    inputs = [
        ProcessingInput(
            source=execution_input["PreprocessedTestDataURL"],
            destination="/opt/ml/processing/input",
            input_name="input-1",
        ),
        ProcessingInput(
            source=execution_input["InferenceCodeURL"],
            destination="/opt/ml/processing/input/code",
            input_name="code",
        ),
        ProcessingInput(
            source=execution_input["ModelS3"],
            destination="/opt/ml/processing/model",
            input_name="model",
        ),
    ]

    outputs = [
        ProcessingOutput(
            source="/opt/ml/processing/test",
            destination=execution_input["OutputPathURL"],
            output_name="test_data",
        ),
    ]

    inference_step = ProcessingStep(
        "SageMaker inference step",
        processor=sklearn_processor2,
        job_name=execution_input["InferenceJobName"],
        inputs=inputs,
        outputs=outputs,
        container_entrypoint=[
            "python3",
            "/opt/ml/processing/input/code/inference.py",
        ],
        kms_key_id=kms_key_id,
    )

    # Create Fail state to mark the workflow failed in case any of the steps fail.
    failed_state_sagemaker_processing_failure = stepfunctions.steps.states.Fail(
        "ML Workflow failed", cause="SageMakerProcessingJobFailed")

    # Add the Error handling in the workflow
    catch_state_processing = stepfunctions.steps.states.Catch(
        error_equals=["States.TaskFailed"],
        next_step=failed_state_sagemaker_processing_failure,
    )

    processing_step.add_catch(catch_state_processing)
    inference_step.add_catch(catch_state_processing)

    # Create the Workflow
    workflow_graph = Chain([processing_step, inference_step])
    inference_pipeline = Workflow(
        name=inference_pipeline_name,
        definition=workflow_graph,
        role=workflow_execution_role,
    )
    return inference_pipeline