def test_model_step(trained_estimator, sfn_client, sagemaker_session,
                    sfn_role_arn):
    """Run a one-step ``ModelStep`` workflow and verify its output.

    A workflow consisting of a single ``ModelStep`` is created, executed,
    and its output is checked for a ``ModelArn`` and an HTTP 200 status.
    The state machine and the created model are removed afterwards, even
    when execution or an assertion fails.

    Args:
        trained_estimator: Fixture whose ``create_model()`` result is
            handed to the ``ModelStep``.
        sfn_client: Client passed to workflow creation and to the state
            machine deletion helper.
        sagemaker_session: Session passed to ``delete_sagemaker_model``.
        sfn_role_arn: Role ARN passed to workflow creation.
    """
    # Build workflow definition
    model_name = generate_job_name()
    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=model_name)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Stays None if the execution fails before producing any output.
        execution_output = None
        try:
            # Execute workflow
            execution = workflow.execute()
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("ModelArn") is not None
            assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200
        finally:
            # Cleanup runs even when execution or an assertion fails, so a
            # failing test does not leave resources behind.
            try:
                state_machine_delete_wait(sfn_client,
                                          workflow.state_machine_arn)
            finally:
                # Only delete the model if the run actually reported one.
                model_arn = (execution_output or {}).get("ModelArn")
                if model_arn is not None:
                    created_model_name = get_resource_name_from_arn(
                        model_arn).split("/")[1]
                    delete_sagemaker_model(created_model_name,
                                           sagemaker_session)
def _get_endpoint_name(execution_output):
    endpoint_arn = execution_output.get('EndpointArn', None)
    endpoint_name = None

    if endpoint_arn is not None:
        resource_name = get_resource_name_from_arn(endpoint_arn)
        endpoint_name = resource_name.split("/")[-1]
    
    return endpoint_name
def test_model_step_with_placeholders(trained_estimator, sfn_client,
                                      sagemaker_session, sfn_role_arn):
    """Run a ``ModelStep`` workflow whose fields come from execution input.

    The model name, the primary container mode and the tags are declared
    as ``ExecutionInput`` placeholders and supplied when the workflow is
    executed. The output is checked for a ``ModelArn`` and an HTTP 200
    status. The state machine and the created model are removed
    afterwards, even when execution or an assertion fails.

    Args:
        trained_estimator: Fixture whose ``create_model()`` result is
            handed to the ``ModelStep``.
        sfn_client: Client passed to workflow creation and to the state
            machine deletion helper.
        sagemaker_session: Session passed to ``delete_sagemaker_model``.
        sfn_role_arn: Role ARN passed to workflow creation.
    """
    # Build workflow definition
    execution_input = ExecutionInput(schema={
        'ModelName': str,
        'Mode': str,
        'Tags': list
    })

    parameters = {
        'PrimaryContainer': {
            'Mode': execution_input['Mode']
        },
        'Tags': execution_input['Tags']
    }

    model_step = ModelStep('create_model_step',
                           model=trained_estimator.create_model(),
                           model_name=execution_input['ModelName'],
                           parameters=parameters)
    model_step.add_retry(SAGEMAKER_RETRY_STRATEGY)
    workflow_graph = Chain([model_step])

    with timeout(minutes=DEFAULT_TIMEOUT_MINUTES):
        # Create workflow and check definition
        workflow = create_workflow_and_check_definition(
            workflow_graph=workflow_graph,
            workflow_name=unique_name_from_base(
                "integ-test-model-step-workflow"),
            sfn_client=sfn_client,
            sfn_role_arn=sfn_role_arn)

        # Concrete values substituted for the placeholders at execution time.
        inputs = {
            'ModelName': generate_job_name(),
            'Mode': 'SingleModel',
            'Tags': [{
                'Key': 'Environment',
                'Value': 'test'
            }]
        }

        # Stays None if the execution fails before producing any output.
        execution_output = None
        try:
            # Execute workflow
            execution = workflow.execute(inputs=inputs)
            execution_output = execution.get_output(wait=True)

            # Check workflow output
            assert execution_output.get("ModelArn") is not None
            assert execution_output["SdkHttpMetadata"]["HttpStatusCode"] == 200
        finally:
            # Cleanup runs even when execution or an assertion fails, so a
            # failing test does not leave resources behind.
            try:
                state_machine_delete_wait(sfn_client,
                                          workflow.state_machine_arn)
            finally:
                # Only delete the model if the run actually reported one.
                model_arn = (execution_output or {}).get("ModelArn")
                if model_arn is not None:
                    model_name = get_resource_name_from_arn(
                        model_arn).split("/")[1]
                    delete_sagemaker_model(model_name, sagemaker_session)