import pytest

from stepfunctions.steps import Task, Pass, Retry, Catch


def test_task_state_creation():
    task_state = Task('Task', resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda')
    task_state.add_retry(Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2))
    task_state.add_retry(Retry(error_equals=['ErrorC'], interval_seconds=5))
    task_state.add_catch(Catch(error_equals=['States.ALL'], next_step=Pass('End State')))

    assert task_state.type == 'Task'
    assert len(task_state.retries) == 2
    assert len(task_state.catches) == 1
    assert task_state.to_dict() == {
        'Type': 'Task',
        'Resource': 'arn:aws:lambda:us-east-1:1234567890:function:StartLambda',
        'Retry': [
            {
                'ErrorEquals': ['ErrorA', 'ErrorB'],
                'IntervalSeconds': 1,
                'BackoffRate': 2,
                'MaxAttempts': 2
            },
            {
                'ErrorEquals': ['ErrorC'],
                'IntervalSeconds': 5
            }
        ],
        'Catch': [
            {
                'ErrorEquals': ['States.ALL'],
                'Next': 'End State'
            }
        ],
        'End': True
    }
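# Minimal companion sketch (not from the original test module): when a next
# step is attached instead of the state being terminal, the serialized state
# is expected to carry 'Next' rather than 'End'. Assumes the State.next() API
# of the Step Functions Data Science SDK.
def test_task_state_with_next_step_sketch():
    task_state = Task('Task', resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda')
    task_state.next(Pass('Next State'))

    assert task_state.to_dict()['Next'] == 'Next State'
    assert 'End' not in task_state.to_dict()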
def test_retry_fail_for_unsupported_state():
    # Pass states do not support Retry, so attaching one should fail.
    s1 = Pass('Step - One')

    with pytest.raises(ValueError):
        s1.add_retry(Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2))
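# Companion sketch (an assumption, mirroring the retry case above): Pass states
# do not support Catch either, so add_catch should raise the same ValueError.
def test_catch_fail_for_unsupported_state_sketch():
    s1 = Pass('Step - One')

    with pytest.raises(ValueError):
        s1.add_catch(Catch(error_equals=['States.ALL'], next_step=Pass('End State')))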
# Copyright 2019 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License").
# You may not use this file except in compliance with the License.
# A copy of the License is located at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# or in the "license" file accompanying this file. This file is distributed
# on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
# express or implied. See the License for the specific language governing
# permissions and limitations under the License.
from __future__ import absolute_import

import os

from stepfunctions.steps import Retry

DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data")
DEFAULT_TIMEOUT_MINUTES = 25

# Default retry strategy for SageMaker steps used in integration tests
SAGEMAKER_RETRY_STRATEGY = Retry(
    error_equals=["SageMaker.AmazonSageMakerException"],
    interval_seconds=5,
    max_attempts=5,
    backoff_rate=2
)
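# Usage sketch (not part of this module): integration tests attach the shared
# strategy to SageMaker-related steps via add_retry(). The Task below is a
# hypothetical stand-in for a real SageMaker step.
#
#     from stepfunctions.steps import Task
#
#     training_step = Task(
#         'TrainModel',
#         resource='arn:aws:states:::sagemaker:createTrainingJob.sync')
#     training_step.add_retry(SAGEMAKER_RETRY_STRATEGY)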
    assert 'ResultPath' not in task_state.to_dict()
    assert 'InputPath' not in task_state.to_dict()
    assert 'OutputPath' not in task_state.to_dict()


def test_default_paths_not_converted_to_null():
    task_state = Task(
        'Task',
        resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda')

    assert '"ResultPath": null' not in task_state.to_json()
    assert '"InputPath": null' not in task_state.to_json()
    assert '"OutputPath": null' not in task_state.to_json()


RETRY = Retry(error_equals=['ErrorA', 'ErrorB'], interval_seconds=1, max_attempts=2, backoff_rate=2)
RETRIES = [RETRY, Retry(error_equals=['ErrorC'], interval_seconds=5)]

EXPECTED_RETRY = [{
    'ErrorEquals': ['ErrorA', 'ErrorB'],
    'IntervalSeconds': 1,
    'BackoffRate': 2,
    'MaxAttempts': 2
}]

EXPECTED_RETRIES = EXPECTED_RETRY + [{
    'ErrorEquals': ['ErrorC'],
    'IntervalSeconds': 5
}]


@pytest.mark.parametrize("retry, expected_retry", [
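# Hypothetical sketch (not from the original source) of how the truncated
# parametrized test above might continue: each (retry, expected_retry) pair is
# attached to a Task and compared against the serialized 'Retry' field.
#
#     @pytest.mark.parametrize("retry, expected_retry", [
#         (RETRY, EXPECTED_RETRY),
#         (RETRIES, EXPECTED_RETRIES)
#     ])
#     def test_retry_serialization_sketch(retry, expected_retry):
#         task = Task('Task', resource='arn:aws:lambda:us-east-1:1234567890:function:StartLambda')
#         task.add_retry(retry)  # assumes add_retry accepts a Retry or a list of Retry
#         assert task.to_dict()['Retry'] == expected_retry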
        'FunctionName': 'CreateAutopilotJob',
        'Payload': {
            'Configuration': {
                'AutoMLJobName': execution_input['AutoMLJobName'],
                'S3InputData': execution_input['S3InputData'],
                'IamRole': execution_input['IamRole'],
                'TargetColumnName': execution_input['TargetColumnName'],
                'S3OutputData': execution_input['S3OutputData'],
                'Tags': execution_input['Tags']
            }
        }
    })

create_autopilot_job_step.add_retry(Retry(
    error_equals=["States.TaskFailed"],
    interval_seconds=15,
    max_attempts=2,
    backoff_rate=4.0
))

create_autopilot_job_step.add_catch(Catch(
    error_equals=["States.TaskFailed"],
    next_step=workflow_failure
))

check_autopilot_job_status = LambdaStep(
    'CheckAutopilotJobStatus',
    parameters={
        'FunctionName': 'CheckAutopilotJobStatus',
        'Payload': {
            'AutopilotJobName': create_autopilot_job_step.output()['Payload']['AutopilotJobName']
        }
    })
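# Sketch (an assumption, not shown in the original excerpt): the two Lambda
# steps above would typically be linked into a single workflow definition with
# Chain before being passed to a Workflow.
#
#     from stepfunctions.steps import Chain
#
#     workflow_definition = Chain([create_autopilot_job_step, check_autopilot_job_status])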