    def create_state_machine(self, lambda_functions, page_sqs):

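        # Wrap-up step: invoke the "wrapup" Lambda once all pages have been processed
        # (CDK v1 aws_stepfunctions / aws_stepfunctions_tasks constructs).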
        task_wrapup = aws_stepfunctions.Task(
            self, "task_wrapup",
            task = aws_stepfunctions_tasks.RunLambdaTask(lambda_functions["wrapup"])
        )

        task_analyze_with_scale = aws_stepfunctions.Task(
            self, "AnalyzeWithScale",
            task = aws_stepfunctions_tasks.SendToQueue(
                queue = page_sqs, 
                message_body = aws_stepfunctions.TaskInput.from_object(
                    {
                        "token": aws_stepfunctions.Context.task_token,
                        "id.$": "$.id",
                        "bucket.$": "$.bucket",
                        "original_upload_pdf.$": "$.original_upload_pdf",
                        "SAGEMAKER_WORKFLOW_AUGMENTED_AI_ARN.$": "$.SAGEMAKER_WORKFLOW_AUGMENTED_AI_ARN",
                        "key.$": "$.key"
                    }
                ),
                delay=None, 
                integration_pattern=aws_stepfunctions.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN
            )
        )

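        # Fan out over every page image in $.image_keys; each iteration posts a message
        # to the page SQS queue and pauses until the worker returns the task token.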
        process_map = aws_stepfunctions.Map(
            self, "Process_Map",
            items_path = "$.image_keys",
            result_path="DISCARD",
            parameters = {
                "id.$": "$.id",
                "bucket.$": "$.bucket",
                "original_upload_pdf.$": "$.original_upload_pdf",
                "SAGEMAKER_WORKFLOW_AUGMENTED_AI_ARN.$": "$.SAGEMAKER_WORKFLOW_AUGMENTED_AI_ARN",
                "key.$": "$$.Map.Item.Value"
            }
        ).iterator(task_analyze_with_scale)

        definition = process_map.next(task_wrapup)

        aws_stepfunctions.StateMachine(
            scope = self, 
            id = "multipagepdfa2i_fancy_stepfunction",
            state_machine_name = "multipagepdfa2i_fancy_stepfunction",
            definition=definition
        )
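    # A minimal sketch (not part of the original stack) of the worker side of the
    # WAIT_FOR_TASK_TOKEN pattern used above: the queue consumer reads the "token"
    # field from the SQS message and reports the result back to Step Functions.
    # process_page() is a hypothetical helper; only boto3 is assumed.
    #
    #   import json
    #   import boto3
    #
    #   sfn_client = boto3.client("stepfunctions")
    #
    #   def handle_message(message):
    #       body = json.loads(message["Body"])
    #       try:
    #           result = process_page(body)  # hypothetical page-processing logic
    #           sfn_client.send_task_success(taskToken=body["token"],
    #                                        output=json.dumps(result))
    #       except Exception as exc:
    #           sfn_client.send_task_failure(taskToken=body["token"],
    #                                        error="PageProcessingError",
    #                                        cause=str(exc))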
    def __init__(self, scope, id, *args, **kwargs):
        super().__init__(scope, id, *args, **kwargs)

        # Buckets
        source_bucket = s3.Bucket(self, "SourceBucket")
        dest_bucket = s3.Bucket(self, "DestinationBucket")
        processing_bucket = s3.Bucket(self, "ProcessingBucket")

        # Lambda Functions
        generate_workflow_input_lambda = aws_lambda.Function(
            self, "GenerateWorkflowInputFunction",
            code=aws_lambda.Code.from_asset(str(DIST_PATH)),
            runtime=aws_lambda.Runtime.PYTHON_3_8,
            handler="generate_workflow_input.lambda_handler",
            environment={
                "InputBucketName": source_bucket.bucket_name,
                "ProcessingBucketName": processing_bucket.bucket_name,
                "OutputBucketName": dest_bucket.bucket_name
            }
        )
        check_workflow_ready_lambda = aws_lambda.Function(
            self, "CheckWorkflowReadyFunction",
            code=aws_lambda.Code.from_asset(str(DIST_PATH)),
            runtime=aws_lambda.Runtime.PYTHON_3_8,
            handler="check_workflow_ready.lambda_handler"
        )
        string_replace_lambda = aws_lambda.Function(
            self, "StringReplaceFunction",
            code=aws_lambda.Code.from_asset(str(DIST_PATH)),
            runtime=aws_lambda.Runtime.PYTHON_3_8,
            handler="string_replace.lambda_handler"
        )
        calculate_total_earnings_lambda = aws_lambda.Function(
            self, "CalculateTotalEarningsFunction",
            code=aws_lambda.Code.from_asset(str(DIST_PATH)),
            runtime=aws_lambda.Runtime.PYTHON_3_8,
            handler="calculate_total_earnings.lambda_handler"
        )
        convert_csv_to_json_lambda = aws_lambda.Function(
            self, "ConvertCsvToJsonFunction",
            code=aws_lambda.Code.from_asset(str(DIST_PATH)),
            runtime=aws_lambda.Runtime.PYTHON_3_8,
            handler="convert_csv_to_json.lambda_handler"
        )

        # Permissions
        source_bucket.grant_read(check_workflow_ready_lambda)
        source_bucket.grant_read(string_replace_lambda)
        processing_bucket.grant_write(string_replace_lambda)
        processing_bucket.grant_read_write(calculate_total_earnings_lambda)
        processing_bucket.grant_read(convert_csv_to_json_lambda)
        dest_bucket.grant_write(convert_csv_to_json_lambda)

        # Outputs
        core.CfnOutput(self, "SourceBucketName", value=source_bucket.bucket_name)
        core.CfnOutput(self, "DestinationBucketName", value=dest_bucket.bucket_name)
        core.CfnOutput(self, "ProcessingBucketName", value=processing_bucket.bucket_name)
        core.CfnOutput(self, "GenerateWorkflowInputLambda", value=generate_workflow_input_lambda.function_name)
        core.CfnOutput(self, "CheckWorkflowReadyLambda", value=check_workflow_ready_lambda.function_name)
        core.CfnOutput(self, "StringReplaceLambda", value=string_replace_lambda.function_name)
        core.CfnOutput(self, "CalculateTotalEarningsLambda", value=calculate_total_earnings_lambda.function_name)
        core.CfnOutput(self, "ConvertCsvToJsonLambda", value=convert_csv_to_json_lambda.function_name)

        # State Machine
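        # payload_response_only=True stores only the Lambda's return value (not the full
        # invocation metadata) at each task's result_path.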
        generate_workflow_input_task = sf_tasks.LambdaInvoke(
            self, "GenerateWorkflowInput",
            lambda_function=generate_workflow_input_lambda,
            payload_response_only=True
        )
        check_workflow_ready_task = sf_tasks.LambdaInvoke(
            self, "CheckWorkflowReady",
            lambda_function=check_workflow_ready_lambda,
            input_path="$.CheckWorkflowReady.Input",
            result_path="$.CheckWorkflowReady.Output",
            payload_response_only=True
        )
        string_replace_task = sf_tasks.LambdaInvoke(
            self, "ReplaceString",
            lambda_function=string_replace_lambda,
            result_path="$.StringReplace.Output",
            payload_response_only=True
        )
        calculate_total_earnings_task = sf_tasks.LambdaInvoke(
            self, "CalculateTotalEarnings",
            lambda_function=calculate_total_earnings_lambda,
            input_path="$.CalculateTotalEarnings.Input",
            result_path="$.CalculateTotalEarnings.Output",
            payload_response_only=True
        )
        convert_csv_to_json_task = sf_tasks.LambdaInvoke(
            self, "ConvertCsvToJson",
            lambda_function=convert_csv_to_json_lambda,
            input_path="$.ConvertCsvToJson.Input",
            result_path="$.ConvertCsvToJson.Output",
            payload_response_only=True
        )

        end_task = sf.Succeed(self, "WorkflowEnd")

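        # Run the string-replace task once per element of $.StringReplace.Input and
        # collect the per-item results under $.StringReplace.Output.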
        replace_string_parallel = sf.Map(
            self, "ReplaceStringParallel",
            items_path="$.StringReplace.Input",
            result_path="$.StringReplace.Output"
        ).iterator(string_replace_task)

        workflow_steps = sf.Chain\
            .start(replace_string_parallel)\
            .next(calculate_total_earnings_task)\
            .next(convert_csv_to_json_task)\
            .next(end_task)

        run_workflow = sf.Choice(self, "RunWorkflowDecision")\
            .when(sf.Condition.boolean_equals("$.CheckWorkflowReady.Output", True), workflow_steps)\
            .otherwise(end_task)

        hello_workflow_state_machine = sf.StateMachine(
            self, "HelloWorkflowStateMachine",
            definition=sf.Chain\
                .start(generate_workflow_input_task)\
                .next(check_workflow_ready_task)\
                .next(run_workflow)
        )
Example #3
    def __init__(self, app: core.App, id: str, **kwargs) -> None:
        super().__init__(app, id, **kwargs)

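        # Step Functions Activities: each task below waits for an external worker to
        # poll the activity, do the work, and report success or failure.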
        submit_job_activity = sfn.Activity(self, "SubmitJob")
        check_job_activity = sfn.Activity(self, "CheckJob")
        do_mapping_activity1 = sfn.Activity(self, "MapJOb1")
        do_mapping_activity2 = sfn.Activity(self, "MapJOb2")

        submit_job = sfn.Task(
            self,
            "Submit Job",
            task=sfn_tasks.InvokeActivity(submit_job_activity),
            result_path="$.guid",
        )

        task1 = sfn.Task(
            self,
            "Task 1 in Mapping",
            task=sfn_tasks.InvokeActivity(do_mapping_activity1),
            result_path="$.guid",
        )

        task2 = sfn.Task(
            self,
            "Task 2 in Mapping",
            task=sfn_tasks.InvokeActivity(do_mapping_activity2),
            result_path="$.guid",
        )

        wait_x = sfn.Wait(
            self,
            "Wait X Seconds",
            time=sfn.WaitTime.seconds_path('$.wait_time'),
        )
        get_status = sfn.Task(
            self,
            "Get Job Status",
            task=sfn_tasks.InvokeActivity(check_job_activity),
            input_path="$.guid",
            result_path="$.status",
        )
        is_complete = sfn.Choice(self, "Job Complete?")
        job_failed = sfn.Fail(self,
                              "Job Failed",
                              cause="AWS Batch Job Failed",
                              error="DescribeJob returned FAILED")
        final_status = sfn.Task(
            self,
            "Get Final Job Status",
            task=sfn_tasks.InvokeActivity(check_job_activity),
            input_path="$.guid",
        )

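        # Two activity tasks chained inside the Map state; up to 10 items run concurrently.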
        definition_map = task1.next(task2)

        process_map = sfn.Map(self, "Process_map",
                              max_concurrency=10).iterator(definition_map)

        definition = submit_job \
            .next(process_map) \
            .next(wait_x) \
            .next(get_status) \
            .next(is_complete
                  .when(sfn.Condition.string_equals(
                    "$.status", "FAILED"), job_failed)
                  .when(sfn.Condition.string_equals(
                    "$.status", "SUCCEEDED"), final_status)
                  .otherwise(wait_x))

        sfn.StateMachine(
            self,
            "StateMachine",
            definition=definition,
            timeout=core.Duration.seconds(30),
        )
Example #4
    def create_unfurl_statemachine(self):
        map_job = sfn.Map(self,
                          "Unfurl Map",
                          items_path="$.links",
                          max_concurrency=10)
        get_note_job = tasks.LambdaInvoke(
            self,
            "Get Note Job for Unfurl",
            lambda_function=self.step_lambda,
            payload=sfn.TaskInput.from_object({
                "action": "get_note_from_url",
                "url.$": "$.url"
            }),
        )
        get_tf_job = tasks.LambdaInvoke(
            self,
            "Get Text Frequency Job for Unfurl",
            lambda_function=self.step_lambda,
            payload=sfn.TaskInput.from_object({
                "action": "update_tf",
                "id.$": "$.Payload.id",
                "url.$": "$.Payload.url",
                "contentUpdatedAt.$": "$.Payload.contentUpdatedAt",
                "isArchived.$": "$.Payload.isArchived",
            }),
        )
        get_idf_job = tasks.LambdaInvoke(
            self,
            "Get Inter Document Frequency Job for Unfurl",
            lambda_function=self.step_lambda,
            payload=sfn.TaskInput.from_object({
                "action": "update_idf",
                "id.$": "$.Payload.id",
                "url.$": "$.Payload.url",
                "contentUpdatedAt.$": "$.Payload.contentUpdatedAt",
                "isArchived.$": "$.Payload.isArchived",
            }),
        )
        get_tfidf_job = tasks.LambdaInvoke(
            self,
            "Get TF*IDF WordCloud Image Job for Unfurl",
            lambda_function=self.step_lambda,
            payload=sfn.TaskInput.from_object({
                "action": "update_tfidf_png",
                "id.$": "$.Payload.id",
                "url.$": "$.Payload.url",
                "contentUpdatedAt.$": "$.Payload.contentUpdatedAt",
                "isArchived.$": "$.Payload.isArchived",
            }),
        )
        unfurl_job = tasks.LambdaInvoke(
            self,
            "Get Attachment Job",
            lambda_function=self.step_lambda,
            payload=sfn.TaskInput.from_object({
                "action": "unfurl",
                "id.$": "$.Payload.id",
                "url.$": "$.Payload.url",
            }),
        )

        get_tf_job.next(get_idf_job.next(get_tfidf_job.next(unfurl_job)))

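        # Decide how much work is needed: go straight to unfurl when the word-cloud image
        # is newer than the note, regenerate only the image when the TF data is current,
        # otherwise recompute everything starting from term frequencies.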
        choice_job = sfn.Choice(self, "Check for Update")
        choice_job.when(
            sfn.Condition.and_(
                sfn.Condition.is_timestamp("$.Payload.tfidfPngUpdatedAt"),
                sfn.Condition.timestamp_less_than_json_path(
                    "$.Payload.contentUpdatedAt",
                    "$.Payload.tfidfPngUpdatedAt"),
            ),
            unfurl_job,
        ).when(
            sfn.Condition.and_(
                sfn.Condition.is_timestamp("$.Payload.tfTsvUpdatedAt"),
                sfn.Condition.timestamp_less_than_json_path(
                    "$.Payload.contentUpdatedAt", "$.Payload.tfTsvUpdatedAt"),
            ),
            get_tfidf_job,
        ).otherwise(get_tf_job)

        unfurl_definition = map_job.iterator(get_note_job.next(choice_job))
        self.unfurl_statemachine = sfn.StateMachine(
            self,
            "UnfurlStateMachine",
            definition=unfurl_definition,
            timeout=core.Duration.minutes(20),
            state_machine_type=sfn.StateMachineType.EXPRESS,
            logs=sfn.LogOptions(
                destination=logs.LogGroup(self, "UnfurlStateMachineLogGroup"),
                level=sfn.LogLevel.ERROR,
            ),
        )
    def __init__(self, scope: core.Construct, id: str, QueueDefine="default",TaskDefine="default",LambdaDefine="default", SNSDefine="default",**kwargs):
        super().__init__(scope, id, **kwargs)

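        # Submit an AWS Batch job (RunBatchJob, CDK v1) that splits the input string;
        # S3 locations are handed to the container through environment overrides.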
        self.Job_String_Split = _sfn.Task(
            self,"String_Split",
            input_path = "$.TaskInfo",
            result_path = "$.JobDetail.String_Split",
            output_path = "$",
            task = _sfn_tasks.RunBatchJob(
                job_name = "String_Split",
                job_definition = TaskDefine.getTaskDefine("String_Split"),
                job_queue = QueueDefine.getComputeQueue("ComputeQueue"),
                container_overrides = _sfn_tasks.ContainerOverrides(
                    environment = {
                        "INPUT_BUCKET":_sfn.Data.string_at("$.BasicParameters.INPUT_BUCKET"),
                        "INPUT_KEY":_sfn.Data.string_at("$.BasicParameters.INPUT_KEY"),
                        "OUTPUT_BUCKET":_sfn.Data.string_at("$.BasicParameters.OUTPUT_BUCKET"),
                        "OUTPUT_KEY":_sfn.Data.string_at("$.JobParameter.String_Split.OUTPUT_KEY"),
                        "SPLIT_NUM":_sfn.Data.string_at("$.JobParameter.String_Split.SPLIT_NUM")
                    }
                )
            )
        )
        
        self.Job_Map = _sfn.Task(
            self,"Job_Map",
            input_path = "$.TaskInfo",
            result_path = "$.TaskInfo.JobDetail.Job_Map",
            output_path = "$",
            task = _sfn_tasks.RunLambdaTask(LambdaDefine.getLambdaFunction("Get_Job_List")),
        )
        
        self.Job_String_Reverse = _sfn.Task(
            self,"String_Reverse",
            input_path = "$",
            result_path = "$",
            output_path = "$",
            task = _sfn_tasks.RunBatchJob(
                job_name = "String_Reverse",
                job_definition = TaskDefine.getTaskDefine("String_Reverse"),
                job_queue = QueueDefine.getComputeQueue("ComputeQueue"),
                container_overrides = _sfn_tasks.ContainerOverrides(
                    environment = {
                        "INDEX":_sfn.Data.string_at("$.INDEX"),
                        "INPUT_BUCKET":_sfn.Data.string_at("$.INPUT_BUCKET"),
                        "INPUT_KEY":_sfn.Data.string_at("$.INPUT_KEY"),
                        "OUTPUT_BUCKET":_sfn.Data.string_at("$.OUTPUT_BUCKET"),
                        "OUTPUT_KEY":_sfn.Data.string_at("$.String_Reverse.OUTPUT_KEY")
                    }
                )
            )
        )
        
        self.Job_String_Repeat = _sfn.Task(
            self,"String_Repeat",
            input_path = "$",
            result_path = "$",
            output_path = "$",
            task = _sfn_tasks.RunBatchJob(
                job_name = "String_Repeat",
                job_definition = TaskDefine.getTaskDefine("String_Repeat"),
                job_queue = QueueDefine.getComputeQueue("ComputeQueue"),
                container_overrides = _sfn_tasks.ContainerOverrides(
                    environment = {
                        "INDEX":_sfn.Data.string_at("$.INDEX"),
                        "INPUT_BUCKET":_sfn.Data.string_at("$.INPUT_BUCKET"),
                        "INPUT_KEY":_sfn.Data.string_at("$.INPUT_KEY"),
                        "OUTPUT_BUCKET":_sfn.Data.string_at("$.OUTPUT_BUCKET"),
                        "OUTPUT_KEY":_sfn.Data.string_at("$.String_Repeat.OUTPUT_KEY")
                    }
                )
            )
        )
        
        self.Job_String_Process_Repeat = _sfn.Map(
            self, "String_Process_Repeat",
            max_concurrency=50,
            input_path = "$.TaskInfo.JobDetail.Job_Map",
            result_path = "DISCARD",
            items_path = "$.Payload",
            output_path = "$",
        ).iterator(self.Job_String_Repeat)
        
        self.Job_String_Repeat_Merge = _sfn.Task(
            self,"String_Repeat_Merge",
            input_path = "$.TaskInfo",
            result_path = "DISCARD",
            output_path = "$",
            task = _sfn_tasks.RunBatchJob(
                job_name = "String_Repeat_Merge",
                job_definition = TaskDefine.getTaskDefine("String_Merge"),
                job_queue = QueueDefine.getComputeQueue("ComputeQueue"),
                container_overrides = _sfn_tasks.ContainerOverrides(
                    environment = {
                        "PERFIX":_sfn.Data.string_at("$.JobParameter.String_Repeat.Prefix"),
                        "FILE_NAME":_sfn.Data.string_at("$.BasicParameters.INPUT_KEY"),
                        "INPUT_BUCKET":_sfn.Data.string_at("$.BasicParameters.INPUT_BUCKET"),
                        "INPUT_KEY":_sfn.Data.string_at("$.JobParameter.String_Repeat.OUTPUT_KEY"),
                        "OUTPUT_BUCKET":_sfn.Data.string_at("$.BasicParameters.OUTPUT_BUCKET"),
                        "OUTPUT_KEY":_sfn.Data.string_at("$.JobParameter.String_Repeat.OUTPUT_KEY")
                    }
                )
            )
        )
        
        self.Job_String_Process_Repeat.next(self.Job_String_Repeat_Merge)
        
        self.Job_String_Process_Reverse = _sfn.Map(
            self, "String_Process_Reverse",
            max_concurrency=50,
            input_path = "$.TaskInfo.JobDetail.Job_Map",
            result_path = "DISCARD",
            items_path = "$.Payload",
            output_path = "$",
        ).iterator(self.Job_String_Reverse)
        
        self.Job_String_Reverse_Merge = _sfn.Task(
            self,"String_Reverse_Merge",
            input_path = "$.TaskInfo",
            result_path = "DISCARD",
            output_path = "$",
            task = _sfn_tasks.RunBatchJob(
                job_name = "String_Reverse_Merge",
                job_definition = TaskDefine.getTaskDefine("String_Merge"),
                job_queue = QueueDefine.getComputeQueue("ComputeQueue"),
                container_overrides = _sfn_tasks.ContainerOverrides(
                    environment = {
                        "PERFIX":_sfn.Data.string_at("$.JobParameter.String_Reverse.Prefix"),
                        "FILE_NAME":_sfn.Data.string_at("$.BasicParameters.INPUT_KEY"),
                        "INPUT_BUCKET":_sfn.Data.string_at("$.BasicParameters.INPUT_BUCKET"),
                        "INPUT_KEY":_sfn.Data.string_at("$.JobParameter.String_Reverse.OUTPUT_KEY"),
                        "OUTPUT_BUCKET":_sfn.Data.string_at("$.BasicParameters.OUTPUT_BUCKET"),
                        "OUTPUT_KEY":_sfn.Data.string_at("$.JobParameter.String_Reverse.OUTPUT_KEY")
                    }
                )
            )
        )
        
        self.Job_String_Process_Reverse.next(self.Job_String_Reverse_Merge)

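        # Run the Repeat and Reverse pipelines as two branches of a Parallel state.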
        self.Job_Parallel_Process = _sfn.Parallel(
            self,
            'Parallel_Process',
            input_path = "$",
            result_path = "DISCARD"
        )
        
        self.Job_Parallel_Process.branch(self.Job_String_Process_Repeat)
        self.Job_Parallel_Process.branch(self.Job_String_Process_Reverse)
        
        self.Job_Check_Output = _sfn.Task(
            self,"Check_Output",
            input_path = "$.TaskInfo",
            
            result_path = "$.JobDetail.Check_Output",
            output_path = "$.JobDetail.Check_Output.Payload",
            task = _sfn_tasks.RunLambdaTask(LambdaDefine.getLambdaFunction("Get_Output_size")),
        )
        
        self.Job_Is_Complete = _sfn.Choice(
            self, "Is_Complete",
            input_path = "$.TaskInfo",
            output_path = "$"
        )
        
        self.Job_Finish = _sfn.Wait(
            self, "Finish",
            time = _sfn.WaitTime.duration(core.Duration.seconds(5))
        )
        
        self.Job_Notification = _sfn.Task(self, "Notification",
            input_path = "$.TaskInfo",
            result_path = "DISCARD",
            output_path = "$",
            task = _sfn_tasks.PublishToTopic(SNSDefine.getSNSTopic("Topic_Batch_Job_Notification"),
                integration_pattern = _sfn.ServiceIntegrationPattern.FIRE_AND_FORGET,
                message = _sfn.TaskInput.from_data_at("$.JobStatus.Job_Comment"),
                subject = _sfn.Data.string_at("$.JobStatus.SNS_Subject")
            )
        )
        
        self.Job_Failed = _sfn.Wait(
            self, "Failed",
            time = _sfn.WaitTime.duration(core.Duration.seconds(5))
        )
        
        self.statemachine = _sfn.StateMachine(
            self, "StateMachine",
            definition = self.Job_String_Split.next(self.Job_Map) \
                .next(self.Job_Parallel_Process) \
                .next(self.Job_Check_Output) \
                .next(self.Job_Notification) \
                .next(self.Job_Is_Complete \
                    .when(_sfn.Condition.string_equals(
                            "$.JobStatus.OutputStatus", "FAILED"
                        ), self.Job_Failed
                            .next(self.Job_Map)
                        )
                    .when(_sfn.Condition.string_equals(
                            "$.JobStatus.OutputStatus", "SUCCEEDED"
                        ), self.Job_Finish)
                    .otherwise(self.Job_Failed)
                ),
            timeout = core.Duration.hours(1),
        )
Example #6
    def __init__(self, scope: core.Construct, id: str, **kwargs) -> None:
        super().__init__(scope, id, **kwargs)

        # A cache to temporarily hold session info
        session_cache_table = aws_dynamodb.Table(
            self,
            'session_cache_table',
            partition_key={
                'name': 'code',
                'type': aws_dynamodb.AttributeType.STRING
            },
            billing_mode=aws_dynamodb.BillingMode.PAY_PER_REQUEST,
            time_to_live_attribute='expires')

        #--
        #  Secrets
        #--------------------#

        # Twitter secrets are stored external to this stack
        twitter_secret = aws_secretsmanager.Secret.from_secret_attributes(
            self,
            'twitter_secret',
            secret_arn=os.environ['TWITTER_SECRET_ARN'])

        #--
        #  Layers
        #--------------------#

        # Each of these dependencies is used in 2 or more functions, extracted to layer for ease of use
        twitter_layer = aws_lambda.LayerVersion(
            self,
            'twitter_layer',
            code=aws_lambda.AssetCode('layers/twitter_layer'),
            compatible_runtimes=[
                aws_lambda.Runtime.PYTHON_2_7, aws_lambda.Runtime.PYTHON_3_6
            ])

        boto_layer = aws_lambda.LayerVersion(
            self,
            'boto_layer',
            code=aws_lambda.AssetCode('layers/boto_layer'),
            compatible_runtimes=[aws_lambda.Runtime.PYTHON_3_6])

        #--
        #  Functions
        #--------------------#

        # Handles CRC validation requests from Twitter
        twitter_crc_func = aws_lambda.Function(
            self,
            "twitter_crc_func",
            code=aws_lambda.AssetCode('functions/twitter_crc_func'),
            handler="lambda.handler",
            layers=[twitter_layer],
            runtime=aws_lambda.Runtime.PYTHON_2_7,
            environment={'twitter_secret': twitter_secret.secret_arn})

        # Grant this function the ability to read Twitter credentials
        twitter_secret.grant_read(twitter_crc_func.role)

        # Handle schedule requests from Twitter
        twitter_webhook_func = aws_lambda.Function(
            self,
            "twitter_webhook_func",
            code=aws_lambda.AssetCode('functions/twitter_webhook_func'),
            handler="lambda.handler",
            layers=[boto_layer, twitter_layer],
            runtime=aws_lambda.Runtime.PYTHON_3_6,
            environment={'twitter_secret': twitter_secret.secret_arn})

        # Grant this function permission to read Twitter credentials
        twitter_secret.grant_read(twitter_webhook_func.role)

        # Grant this function permission to publish tweets to EventBridge
        twitter_webhook_func.add_to_role_policy(
            aws_iam.PolicyStatement(actions=["events:PutEvents"],
                                    resources=["*"]))

        # Use API Gateway as the webhook endpoint
        twitter_api = aws_apigateway.LambdaRestApi(
            self, 'twitter_api', handler=twitter_webhook_func, proxy=False)

        # Tweets are POSTed to the endpoint
        twitter_api.root.add_method('POST')

        # Handles twitter CRC validation requests via GET to the webhook
        twitter_api.root.add_method(
            'GET', aws_apigateway.LambdaIntegration(twitter_crc_func))

        # Extract relevant info from the tweet, including session codes
        parse_tweet_func = aws_lambda.Function(
            self,
            "parse_tweet_func",
            code=aws_lambda.AssetCode('functions/parse_tweet_func'),
            handler="lambda.handler",
            runtime=aws_lambda.Runtime.PYTHON_3_6)

        # Get session information for requested codes
        get_sessions_func = aws_lambda.Function(
            self,
            "get_sessions_func",
            code=aws_lambda.AssetCode('functions/get_sessions_func'),
            handler="lambda.handler",
            runtime=aws_lambda.Runtime.PYTHON_3_6,
            timeout=core.Duration.seconds(60),
            layers=[boto_layer],
            environment={
                'CACHE_TABLE': session_cache_table.table_name,
                'LOCAL_CACHE_TTL':
                str(1 * 60 * 60),  # Cache sessions locally for 1 hour
                'REMOTE_CACHE_TTL': str(12 * 60 * 60)
            })  # Cache sessions remotely for 12 hours

        # This functions needs permissions to read and write to the table
        session_cache_table.grant_write_data(get_sessions_func)
        session_cache_table.grant_read_data(get_sessions_func)

        # Create a schedule without conflicts
        create_schedule_func = aws_lambda.Function(
            self,
            "create_schedule_func",
            code=aws_lambda.AssetCode('functions/create_schedule_func'),
            handler="lambda.handler",
            runtime=aws_lambda.Runtime.PYTHON_3_6,
            timeout=core.Duration.seconds(60))

        # Tweet the response to the user
        tweet_schedule_func = aws_lambda.Function(
            self,
            "tweet_schedule_func",
            code=aws_lambda.AssetCode('functions/tweet_schedule_func'),
            handler="lambda.handler",
            layers=[boto_layer, twitter_layer],
            runtime=aws_lambda.Runtime.PYTHON_3_6,
            environment={'twitter_secret': twitter_secret.secret_arn})
        twitter_secret.grant_read(tweet_schedule_func.role)

        #--
        #  States
        #--------------------#

        # Step 4
        tweet_schedule_job = aws_stepfunctions.Task(
            self,
            'tweet_schedule_job',
            task=aws_stepfunctions_tasks.InvokeFunction(tweet_schedule_func))

        # Step 3
        create_schedule_job = aws_stepfunctions.Task(
            self,
            'create_schedule_job',
            task=aws_stepfunctions_tasks.InvokeFunction(create_schedule_func),
            input_path="$.sessions",
            result_path="$.schedule")
        create_schedule_job.next(tweet_schedule_job)

        # Step 2 - Get associated sessions (scrape or cache)
        get_sessions_job = aws_stepfunctions.Task(
            self,
            'get_sessions_job',
            task=aws_stepfunctions_tasks.InvokeFunction(get_sessions_func))

        # Prepare to get session info in parallel using the Map state
        get_sessions_map = aws_stepfunctions.Map(self,
                                                 'get_sessions_map',
                                                 items_path="$.codes",
                                                 result_path="$.sessions")
        get_sessions_map.iterator(get_sessions_job)
        get_sessions_map.next(create_schedule_job)

        # Shortcut if no session codes are supplied
        check_num_codes = aws_stepfunctions.Choice(self, 'check_num_codes')
        check_num_codes.when(
            aws_stepfunctions.Condition.number_greater_than('$.num_codes', 0),
            get_sessions_map)
        check_num_codes.otherwise(aws_stepfunctions.Succeed(self, "no_codes"))

        # Step 1 - Parse incoming tweet and prepare for scheduling
        parse_tweet_job = aws_stepfunctions.Task(
            self,
            'parse_tweet_job',
            task=aws_stepfunctions_tasks.InvokeFunction(parse_tweet_func))
        parse_tweet_job.next(check_num_codes)

        #--
        #  State Machines
        #--------------------#

        schedule_machine = aws_stepfunctions.StateMachine(
            self, "schedule_machine", definition=parse_tweet_job)

        # A rule to filter reInventSched tweet events
        reinvent_sched_rule = aws_events.Rule(
            self,
            "reinvent_sched_rule",
            event_pattern={"source": ["reInventSched"]})

        # Matching events start the schedule state machine
        reinvent_sched_rule.add_target(
            aws_events_targets.SfnStateMachine(
                schedule_machine,
                input=aws_events.RuleTargetInput.from_event_path("$.detail")))
Example #7
    def __init__(self, app: core.App, id: str, props, **kwargs) -> None:
        super().__init__(app, id, **kwargs)

        run_data_bucket_name = ''
        run_data_bucket = s3.Bucket.from_bucket_name(
            self, run_data_bucket_name, bucket_name=run_data_bucket_name)

        # IAM roles for the lambda functions
        lambda_role = iam.Role(
            self,
            'EchoTesLambdaRole',
            assumed_by=iam.ServicePrincipal('lambda.amazonaws.com'),
            managed_policies=[
                iam.ManagedPolicy.from_aws_managed_policy_name(
                    'service-role/AWSLambdaBasicExecutionRole')
            ])

        copy_lambda_role = iam.Role(
            self,
            'CopyToS3LambdaRole',
            assumed_by=iam.ServicePrincipal('lambda.amazonaws.com'),
            managed_policies=[
                iam.ManagedPolicy.from_aws_managed_policy_name(
                    'service-role/AWSLambdaBasicExecutionRole')
            ])
        run_data_bucket.grant_write(copy_lambda_role)

        callback_role = iam.Role(
            self,
            'CallbackTesLambdaRole',
            assumed_by=iam.ServicePrincipal('lambda.amazonaws.com'),
            managed_policies=[
                iam.ManagedPolicy.from_aws_managed_policy_name(
                    'service-role/AWSLambdaBasicExecutionRole'),
                iam.ManagedPolicy.from_aws_managed_policy_name(
                    'AWSStepFunctionsFullAccess')
            ])

        # Lambda function to call back and complete SFN async tasks
        lmbda.Function(self,
                       'CallbackLambda',
                       function_name='callback_iap_tes_lambda_dev',
                       handler='callback.lambda_handler',
                       runtime=lmbda.Runtime.PYTHON_3_7,
                       code=lmbda.Code.from_asset('lambdas'),
                       role=callback_role,
                       timeout=core.Duration.seconds(20))

        samplesheet_mapper_function = lmbda.Function(
            self,
            'SampleSheetMapperTesLambda',
            function_name='showcase_ss_mapper_iap_tes_lambda_dev',
            handler='launch_tes_task.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'TASK_ID': props['task_id'],
                'TASK_VERSION': 'tvn.0ee81865bf514b7bb7b7ea305c88191f',
                # 'TASK_VERSION': 'tvn.b4735419fbe4455eb2b91960e48921f9',  # echo task
                'SSM_PARAM_JWT': props['ssm_param_name'],
                'GDS_LOG_FOLDER': props['gds_log_folder'],
                'IMAGE_NAME': 'umccr/alpine_pandas',
                'IMAGE_TAG': '1.0.1',
                'TES_TASK_NAME': 'SampleSheetMapper'
            })

        bcl_convert_function = lmbda.Function(
            self,
            'BclConvertTesLambda',
            function_name='showcase_bcl_convert_iap_tes_lambda_dev',
            handler='launch_tes_task.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'TASK_ID': props['task_id'],
                'TASK_VERSION': 'tvn.ab3e85f9aaf24890ad169fdab3825c0d',
                # 'TASK_VERSION': 'tvn.b4735419fbe4455eb2b91960e48921f9',  # echo task
                'SSM_PARAM_JWT': props['ssm_param_name'],
                'GDS_LOG_FOLDER': props['gds_log_folder'],
                'IMAGE_NAME':
                '699120554104.dkr.ecr.us-east-1.amazonaws.com/public/dragen',
                'IMAGE_TAG': '3.5.2',
                'TES_TASK_NAME': 'BclConvert'
            })

        fastq_mapper_function = lmbda.Function(
            self,
            'FastqMapperTesLambda',
            function_name='showcase_fastq_mapper_iap_tes_lambda_dev',
            handler='launch_tes_task.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'TASK_ID': props['task_id'],
                'TASK_VERSION': 'tvn.f90aa88da2fe490fb6e6366b65abe267',
                # 'TASK_VERSION': 'tvn.b4735419fbe4455eb2b91960e48921f9',  # echo task
                'SSM_PARAM_JWT': props['ssm_param_name'],
                'GDS_LOG_FOLDER': props['gds_log_folder'],
                'IMAGE_NAME': 'umccr/alpine_pandas',
                'IMAGE_TAG': '1.0.1',
                'TES_TASK_NAME': 'FastqMapper'
            })

        gather_samples_function = lmbda.Function(
            self,
            'GatherSamplesTesLambda',
            function_name='showcase_gather_samples_iap_tes_lambda_dev',
            handler='gather_samples.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'SSM_PARAM_JWT': props['ssm_param_name']
            })

        dragen_function = lmbda.Function(
            self,
            'DragenTesLambda',
            function_name='showcase_dragen_iap_tes_lambda_dev',
            handler='launch_tes_task.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'TASK_ID': props['task_id'],
                'TASK_VERSION': 'tvn.096b39e90e4443abae0333e23fcabc61',
                # 'TASK_VERSION': 'tvn.b4735419fbe4455eb2b91960e48921f9',  # echo task
                'SSM_PARAM_JWT': props['ssm_param_name'],
                'GDS_LOG_FOLDER': props['gds_log_folder'],
                'IMAGE_NAME':
                '699120554104.dkr.ecr.us-east-1.amazonaws.com/public/dragen',
                'IMAGE_TAG': '3.5.2',
                'TES_TASK_NAME': 'Dragen'
            })

        multiqc_function = lmbda.Function(
            self,
            'MultiQcTesLambda',
            function_name='showcase_multiqc_iap_tes_lambda_dev',
            handler='launch_tes_task.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'TASK_ID': props['task_id'],
                'TASK_VERSION': 'tvn.983a0239483d4253a8a0531fa1de0376',
                # 'TASK_VERSION': 'tvn.b4735419fbe4455eb2b91960e48921f9',  # echo task
                'SSM_PARAM_JWT': props['ssm_param_name'],
                'GDS_LOG_FOLDER': props['gds_log_folder'],
                'IMAGE_NAME': 'umccr/multiqc_dragen',
                'IMAGE_TAG': '1.1',
                'TES_TASK_NAME': 'MultiQC'
            })

        copy_report_to_s3 = lmbda.Function(
            self,
            'CopyReportToS3Lambda',
            function_name='showcase_copy_report_lambda_dev',
            handler='copy_to_s3.lambda_handler',
            runtime=lmbda.Runtime.PYTHON_3_7,
            code=lmbda.Code.from_asset('lambdas'),
            role=copy_lambda_role,
            timeout=core.Duration.seconds(20),
            environment={
                'IAP_API_BASE_URL': props['iap_api_base_url'],
                'SSM_PARAM_JWT': props['ssm_param_name'],
                'GDS_RUN_VOLUME': props['gds_run_volume'],
                'S3_RUN_BUCKET': props['s3_run_bucket']
            })

        # IAP JWT access token stored in SSM Parameter Store
        secret_value = ssm.StringParameter.from_secure_string_parameter_attributes(
            self,
            "JwtToken",
            parameter_name=props['ssm_param_name'],
            version=props['ssm_param_version'])
        secret_value.grant_read(samplesheet_mapper_function)
        secret_value.grant_read(bcl_convert_function)
        secret_value.grant_read(fastq_mapper_function)
        secret_value.grant_read(gather_samples_function)
        secret_value.grant_read(dragen_function)
        secret_value.grant_read(multiqc_function)
        secret_value.grant_read(copy_report_to_s3)

        # SFN task definitions
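        # The launch tasks below use the WAIT_FOR_TASK_TOKEN integration: each Lambda
        # starts a TES job, and the separate callback Lambda completes the step by
        # returning the task token passed in the payload.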
        task_samplesheet_mapper = sfn.Task(
            self,
            "SampleSheetMapper",
            task=sfn_tasks.RunLambdaTask(
                samplesheet_mapper_function,
                integration_pattern=sfn.ServiceIntegrationPattern.
                WAIT_FOR_TASK_TOKEN,
                payload={
                    "taskCallbackToken": sfn.Context.task_token,
                    "runId.$": "$.runfolder"
                }),
            result_path="$.guid")

        task_bcl_convert = sfn.Task(
            self,
            "BclConvert",
            task=sfn_tasks.RunLambdaTask(
                bcl_convert_function,
                integration_pattern=sfn.ServiceIntegrationPattern.
                WAIT_FOR_TASK_TOKEN,
                payload={
                    "taskCallbackToken": sfn.Context.task_token,
                    "runId.$": "$.runfolder"
                }),
            result_path="$.guid")

        task_fastq_mapper = sfn.Task(
            self,
            "FastqMapper",
            task=sfn_tasks.RunLambdaTask(
                fastq_mapper_function,
                integration_pattern=sfn.ServiceIntegrationPattern.
                WAIT_FOR_TASK_TOKEN,
                payload={
                    "taskCallbackToken": sfn.Context.task_token,
                    "runId.$": "$.runfolder"
                }),
            result_path="$.guid")

        task_gather_samples = sfn.Task(self,
                                       "GatherSamples",
                                       task=sfn_tasks.InvokeFunction(
                                           gather_samples_function,
                                           payload={"runId.$": "$.runfolder"}),
                                       result_path="$.sample_ids")

        task_dragen = sfn.Task(
            self,
            "DragenTask",
            task=sfn_tasks.RunLambdaTask(
                dragen_function,
                integration_pattern=sfn.ServiceIntegrationPattern.
                WAIT_FOR_TASK_TOKEN,
                payload={
                    "taskCallbackToken": sfn.Context.task_token,
                    "runId.$": "$.runId",
                    "index.$": "$.index",
                    "item.$": "$.item"
                }),
            result_path="$.exit_status")

        task_multiqc = sfn.Task(
            self,
            "MultiQcTask",
            task=sfn_tasks.RunLambdaTask(
                multiqc_function,
                integration_pattern=sfn.ServiceIntegrationPattern.
                WAIT_FOR_TASK_TOKEN,
                payload={
                    "taskCallbackToken": sfn.Context.task_token,
                    "runId.$": "$.runfolder",
                    "samples.$": "$.sample_ids"
                }))

        task_copy_report_to_s3 = sfn.Task(
            self,
            "CopyReportToS3",
            task=sfn_tasks.InvokeFunction(copy_report_to_s3,
                                          payload={"runId.$": "$.runfolder"}),
            result_path="$.copy_report")

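        # Scatter: run the Dragen task once per sample, up to 20 in parallel, passing the
        # per-item index and value from the Map context ($$.Map.Item.Index / Value).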
        scatter = sfn.Map(self,
                          "Scatter",
                          items_path="$.sample_ids",
                          parameters={
                              "index.$": "$$.Map.Item.Index",
                              "item.$": "$$.Map.Item.Value",
                              "runId.$": "$.runfolder"
                          },
                          result_path="$.mapresults",
                          max_concurrency=20).iterator(task_dragen)

        definition = task_samplesheet_mapper \
            .next(task_bcl_convert) \
            .next(task_fastq_mapper) \
            .next(task_gather_samples) \
            .next(scatter) \
            .next(task_multiqc) \
            .next(task_copy_report_to_s3)

        sfn.StateMachine(
            self,
            "ShowcaseSfnStateMachine",
            definition=definition,
        )
Example #8
    def __init__(self, scope: Construct, construct_id: str, **kwargs) -> None:
        super().__init__(scope, construct_id, **kwargs)

        #the S3 bucket where CloudFront Access Logs will be stored
        cf_access_logs = s3.Bucket(self, "LogBucket")

        #S3 bucket where Athena will put the results
        athena_results = s3.Bucket(self, "AthenaResultsBucket")

        #create an Athena database
        glue_database_name = "serverlessland_database"
        myDatabase = glue.CfnDatabase(
            self,
            id=glue_database_name,
            catalog_id=account,
            database_input=glue.CfnDatabase.DatabaseInputProperty(
                description=f"Glue database '{glue_database_name}'",
                name=glue_database_name,
            )
        )

        #define a table with the structure of CloudFront Logs https://docs.aws.amazon.com/athena/latest/ug/cloudfront-logs.html
        athena_table = glue.CfnTable(self,
            id='cfaccesslogs',
            catalog_id=account,
            database_name=glue_database_name,
            table_input=glue.CfnTable.TableInputProperty(
                name='cf_access_logs',
                description='CloudFront access logs',
                table_type='EXTERNAL_TABLE',
                parameters = {
                    'skip.header.line.count': '2',
                },
                storage_descriptor=glue.CfnTable.StorageDescriptorProperty(
                    location="s3://"+cf_access_logs.bucket_name+"/",
                    input_format='org.apache.hadoop.mapred.TextInputFormat',
                    output_format='org.apache.hadoop.hive.ql.io.HiveIgnoreKeyTextOutputFormat',
                    compressed=False,
                    serde_info=glue.CfnTable.SerdeInfoProperty(
                        serialization_library='org.apache.hadoop.hive.serde2.lazy.LazySimpleSerDe',
                        parameters={
                            'field.delim': '\t'
                        }
                    ),
                    columns=[
                        glue.CfnTable.ColumnProperty(name='date', type='date'),
                        glue.CfnTable.ColumnProperty(name='time', type='string'),
                        glue.CfnTable.ColumnProperty(name='location', type='string'),
                        glue.CfnTable.ColumnProperty(name='bytes', type='bigint'),
                        glue.CfnTable.ColumnProperty(name='request_ip', type='string'),
                        glue.CfnTable.ColumnProperty(name='method', type='string'),
                        glue.CfnTable.ColumnProperty(name='host', type='string'),
                        glue.CfnTable.ColumnProperty(name='uri', type='string'),
                        glue.CfnTable.ColumnProperty(name='status', type='string'),
                        glue.CfnTable.ColumnProperty(name='referer', type='string'),
                        glue.CfnTable.ColumnProperty(name='user_agent', type='string'),
                        glue.CfnTable.ColumnProperty(name='query_string', type='string'),
                        glue.CfnTable.ColumnProperty(name='cookie', type='string'),
                        glue.CfnTable.ColumnProperty(name='result_type', type='string'),
                        glue.CfnTable.ColumnProperty(name='request_id', type='string'),
                        glue.CfnTable.ColumnProperty(name='host_header', type='string'),
                        glue.CfnTable.ColumnProperty(name='request_protocol', type='string'),
                        glue.CfnTable.ColumnProperty(name='request_bytes', type='bigint'),
                        glue.CfnTable.ColumnProperty(name='time_taken', type='float'),
                        glue.CfnTable.ColumnProperty(name='xforwarded_for', type='string'),
                        glue.CfnTable.ColumnProperty(name='ssl_protocol', type='string'),
                        glue.CfnTable.ColumnProperty(name='ssl_cipher', type='string'),
                        glue.CfnTable.ColumnProperty(name='response_result_type', type='string'),
                        glue.CfnTable.ColumnProperty(name='http_version', type='string'),
                        glue.CfnTable.ColumnProperty(name='fle_status', type='string'),
                        glue.CfnTable.ColumnProperty(name='fle_encrypted_fields', type='int'),
                        glue.CfnTable.ColumnProperty(name='c_port', type='int'),
                        glue.CfnTable.ColumnProperty(name='time_to_first_byte', type='float'),
                        glue.CfnTable.ColumnProperty(name='x_edge_detailed_result_type', type='string'),
                        glue.CfnTable.ColumnProperty(name='sc_content_type', type='string'),
                        glue.CfnTable.ColumnProperty(name='sc_content_len', type='string'),
                        glue.CfnTable.ColumnProperty(name='sc_range_start', type='bigint'),
                        glue.CfnTable.ColumnProperty(name='sc_range_end', type='bigint')
                    ]
                ),
            )
        )

        #submit the query and wait for the results
        start_query_execution_job = tasks.AthenaStartQueryExecution(self, "Start Athena Query",
            query_string="SELECT uri FROM cf_access_logs limit 10",
            integration_pattern=sf.IntegrationPattern.RUN_JOB, #executes the command in SYNC mode
            query_execution_context=tasks.QueryExecutionContext(
                database_name=glue_database_name
            ),
            result_configuration=tasks.ResultConfiguration(
                output_location=s3.Location(
                    bucket_name=athena_results.bucket_name,
                    object_key="results"
                )
            )
        )

        #get the results
        get_query_results_job = tasks.AthenaGetQueryResults(self, "Get Query Results",
            query_execution_id=sf.JsonPath.string_at("$.QueryExecution.QueryExecutionId"),
            result_path=sf.JsonPath.string_at("$.GetQueryResults"),
        )

        #prepare the query to see if more results are available (up to 1000 can be retrieved)
        prepare_next_params = sf.Pass(self, "Prepare Next Query Params",
            parameters={
                "QueryExecutionId.$": "$.StartQueryParams.QueryExecutionId",
                "NextToken.$": "$.GetQueryResults.NextToken"
            },
            result_path=sf.JsonPath.string_at("$.StartQueryParams")
        )

        #check to see if more results are available
        has_more_results = sf.Choice(self, "Has More Results?").when(
                    sf.Condition.is_present("$.GetQueryResults.NextToken"),
                    prepare_next_params.next(get_query_results_job)
                ).otherwise(sf.Succeed(self, "Done"))


        #do something with each result
        #here add your own logic
        map = sf.Map(self, "Map State",
            max_concurrency=1,
            input_path=sf.JsonPath.string_at("$.GetQueryResults.ResultSet.Rows[1:]"),
            result_path = sf.JsonPath.DISCARD
        )
        map.iterator(sf.Pass(self, "DoSomething"))


        # Step function to orchestrate Athena query and retrieving the results
        workflow = sf.StateMachine(self, "AthenaQuery",
            definition=start_query_execution_job.next(get_query_results_job).next(map).next(has_more_results),
            timeout=Duration.minutes(60)
        )

        CfnOutput(self, "Logs",
            value=cf_access_logs.bucket_name, export_name='LogsBucket')

        CfnOutput(self, "SFName",
            value=workflow.state_machine_name, export_name='SFName')

        CfnOutput(self, "SFArn",
            value = workflow.state_machine_arn,
            export_name = 'StepFunctionArn',
            description = 'Step Function arn')
    def __init__(self, scope: Construct, id: str, log_level: CfnParameter):
        super().__init__(scope, id)
        self._bundling = {}
        self.log_level = log_level.value_as_string
        self.source_path = Path(__file__).parent.parent.parent.parent
        self.topic = None
        self.subscription = None
        self.functions: Dict[str, Function] = {}
        self.policies = Policies(self)
        self.create_functions()

        # step function steps
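        # Service and States.ALL errors are routed to the SNS notifier Lambda and then
        # to a Fail state.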
        check_error = sfn.Choice(self, "Check-Error")
        notify_failed = tasks.LambdaInvoke(
            self,
            "Notify-Failed",
            lambda_function=self.functions["SNS"],
            payload_response_only=True,
            retry_on_service_exceptions=True,
            result_path=None,
        )
        notify_failed.next(sfn.Fail(self, "FailureState"))

        create_dataset_group = tasks.LambdaInvoke(
            self,
            "Create-DatasetGroup",
            lambda_function=self.functions["CreateDatasetGroup"],
            result_path="$.DatasetGroupNames",
            payload_response_only=True,
            retry_on_service_exceptions=True,
        )
        create_dataset_group.add_retry(backoff_rate=1.05,
                                       interval=Duration.seconds(5),
                                       errors=["ResourcePending"])
        create_dataset_group.add_catch(notify_failed,
                                       errors=["ResourceFailed"],
                                       result_path="$.serviceError")
        create_dataset_group.add_catch(notify_failed,
                                       errors=["States.ALL"],
                                       result_path="$.statesError")

        import_data = tasks.LambdaInvoke(
            self,
            "Import-Data",
            lambda_function=self.functions["CreateDatasetImportJob"],
            result_path="$.DatasetImportJobArn",
            payload_response_only=True,
            retry_on_service_exceptions=True,
        )
        import_data.add_retry(
            backoff_rate=1.05,
            interval=Duration.seconds(5),
            max_attempts=100,
            errors=["ResourcePending"],
        )
        import_data.add_catch(notify_failed,
                              errors=["ResourceFailed"],
                              result_path="$.serviceError")
        import_data.add_catch(notify_failed,
                              errors=["States.ALL"],
                              result_path="$.statesError")

        update_not_required = sfn.Succeed(self, "Update-Not-Required")
        notify_success = tasks.LambdaInvoke(
            self,
            "Notify-Success",
            lambda_function=self.functions["SNS"],
            payload_response_only=True,
            retry_on_service_exceptions=True,
            result_path=None,
        )

        notify_prediction_failed = tasks.LambdaInvoke(
            self,
            "Notify-Prediction-Failed",
            lambda_function=self.functions["SNS"],
            payload_response_only=True,
            retry_on_service_exceptions=True,
            result_path=None,
        )
        notify_prediction_failed.next(sfn.Fail(self, "Prediction-Failed"))

        create_predictor = tasks.LambdaInvoke(
            self,
            "Create-Predictor",
            lambda_function=self.functions["CreatePredictor"],
            result_path="$.PredictorArn",
            payload_response_only=True,
            retry_on_service_exceptions=True,
        )
        create_predictor.add_retry(
            backoff_rate=1.05,
            interval=Duration.seconds(5),
            max_attempts=100,
            errors=["ResourcePending", "DatasetsImporting"],
        )
        create_predictor.add_catch(
            notify_prediction_failed,
            errors=["ResourceFailed"],
            result_path="$.serviceError",
        )
        create_predictor.add_catch(notify_prediction_failed,
                                   errors=["States.ALL"],
                                   result_path="$.statesError")
        create_predictor.add_catch(update_not_required,
                                   errors=["NotMostRecentUpdate"])

        create_forecast = tasks.LambdaInvoke(
            self,
            "Create-Forecast",
            lambda_function=self.functions["CreateForecast"],
            result_path="$.ForecastArn",
            payload_response_only=True,
            retry_on_service_exceptions=True,
        )
        create_forecast.add_retry(
            backoff_rate=1.05,
            interval=Duration.seconds(5),
            max_attempts=100,
            errors=["ResourcePending"],
        )
        create_forecast.add_catch(
            notify_prediction_failed,
            errors=["ResourceFailed"],
            result_path="$.serviceError",
        )
        create_forecast.add_catch(notify_prediction_failed,
                                  errors=["States.ALL"],
                                  result_path="$.statesError")

        export_forecast = tasks.LambdaInvoke(
            self,
            "Export-Forecast",
            lambda_function=self.functions["PrepareForecastExport"],
            result_path="$.ExportTableName",
            payload_response_only=True,
            retry_on_service_exceptions=True,
        )
        export_forecast.add_catch(
            notify_prediction_failed,
            errors=["ResourceFailed"],
            result_path="$.serviceError",
        )
        export_forecast.add_catch(notify_prediction_failed,
                                  errors=["States.ALL"],
                                  result_path="$.statesError")

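        # One Map iteration per dataset group: train a predictor, create and export the
        # forecast, then send the success notification.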
        create_forecasts = sfn.Map(
            self,
            "Create-Forecasts",
            items_path="$.DatasetGroupNames",
            parameters={
                "bucket.$": "$.bucket",
                "dataset_file.$": "$.dataset_file",
                "dataset_group_name.$": "$$.Map.Item.Value",
                "config.$": "$.config",
            },
        )

        # step function definition
        definition = (check_error.when(
            sfn.Condition.is_present("$.serviceError"),
            notify_failed).otherwise(create_dataset_group).afterwards().next(
                import_data).next(
                    create_forecasts.iterator(
                        create_predictor.next(create_forecast).next(
                            export_forecast).next(notify_success))))

        self.state_machine = sfn.StateMachine(self,
                                              "DeployStateMachine",
                                              definition=definition)
    def create_state_machine(self, services):

        task_pngextract = aws_stepfunctions_tasks.LambdaInvoke(
            self, "PDF. Conver to PNGs",
            lambda_function = services["lambda"]["pngextract"],
            payload_response_only=True,
            result_path = "$.image_keys"
        )

        task_wrapup = aws_stepfunctions_tasks.LambdaInvoke(
            self, "Wrapup and Clean",
            lambda_function = services["lambda"]["wrapup"]
        )

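        # Send each page to the Textract/A2I worker queue and wait for the task-token
        # callback before the Map iteration completes.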
        iterate_sqs_to_textract = aws_stepfunctions_tasks.SqsSendMessage(
            self, "Perform Textract and A2I",
            queue=services["textract_sqs"], 
            message_body = aws_stepfunctions.TaskInput.from_object({
                "token": aws_stepfunctions.Context.task_token,
                "id.$": "$.id",
                "bucket.$": "$.bucket",
                "key.$": "$.key",
                "wip_key.$": "$.wip_key"
            }),
            delay= None,
            integration_pattern=aws_stepfunctions.ServiceIntegrationPattern.WAIT_FOR_TASK_TOKEN
        )

        process_map = aws_stepfunctions.Map(
            self, "Process_Map",
            items_path = "$.image_keys",
            result_path="DISCARD",
            parameters = {
                "id.$": "$.id",
                "bucket.$": "$.bucket",
                "key.$": "$.key",
                "wip_key.$": "$$.Map.Item.Value"
            }
        ).iterator(iterate_sqs_to_textract)
        
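        # Single images skip PNG extraction: a Pass state injects a one-element
        # $.image_keys array so the Map state still runs exactly once.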
        choice_pass = aws_stepfunctions.Pass(
            self,
            "Image. Passing.",
            result=aws_stepfunctions.Result.from_array(["single_image"]),
            result_path="$.image_keys"
        )

        pdf_or_image_choice = aws_stepfunctions.Choice(self, "PDF or Image?")
        pdf_or_image_choice.when(aws_stepfunctions.Condition.string_equals("$.extension", "pdf"), task_pngextract)
        pdf_or_image_choice.when(aws_stepfunctions.Condition.string_equals("$.extension", "png"), choice_pass)
        pdf_or_image_choice.when(aws_stepfunctions.Condition.string_equals("$.extension", "jpg"), choice_pass)
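        # .afterwards() in the definition below rejoins all choice branches before
        # the Map state.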

        # Creates the Step Functions
        multipagepdfa2i_sf = aws_stepfunctions.StateMachine(
            scope = self, 
            id = "multipagepdfa2i_stepfunction",
            state_machine_name = "multipagepdfa2i_stepfunction",
            definition=pdf_or_image_choice.afterwards().next(process_map).next(task_wrapup)
        )

        return multipagepdfa2i_sf