Example #1
0
 def error_callback(carrier, error, callback):
     """
     A wrapper for use when returning errors to the calling Task.

     The wrapper mainly provides OpenTracing boilerplate around invoking
     ``callback(error)``: the error is tagged and logged on a span first.

     carrier: an OpenTracing text_map Carrier. Provide it when we don't
         have the original span relating to the request; a new "Task" span
         is then started as a child of the span context extracted from it.
         If we do have the original span, carrier should be None and the
         span should be managed in the code calling this function (the
         currently active scope is used). If carrier is empty/None and no
         scope is active, a new parentless "Task" span is started instead.
     error: dict with "errorType" and "errorMessage" keys; both are logged
         on the span and the dict is passed unchanged to callback.
     callback: invoked with ``error`` if callable. It is passed in because
         this function may be called from something asynchronous like a
         timeout, where the callback may have been stored in
         pending_requests rather than using the callback wrapped in the
         closure of execute_task.
     """
     scope = None
     if carrier:
         scope = opentracing.tracer.start_active_span(
             operation_name="Task",
             child_of=span_context("text_map", carrier, self.logger),
         )
     else:
         scope = opentracing.tracer.scope_manager.active
         if scope is None:
             # Callers pass context.get("Tracer", {}), so an empty carrier
             # can arrive with nothing active; `with None` would raise and
             # the error would never reach the Task. Start a root span.
             scope = opentracing.tracer.start_active_span(operation_name="Task")

     # NOTE(review): when the caller's active scope is reused, leaving this
     # `with` closes that scope too (as the original code did) — confirm the
     # caller's own later close is harmless for the scope manager in use.
     with scope:
         span = scope.span
         span.set_tag("error", True)
         span.log_kv(
             {
                 "event": error["errorType"],
                 "message": error["errorMessage"],
             }
         )
         if callable(callback):
             callback(error)
Example #2
0
    def handler(self, message):
        """
        Consume an RPC request message and send a reply to its reply_to.

        A "consumer" span is opened as a child of the tracing context found
        in message.properties. Inside it, a "producer" span wraps sending the
        reply, and that span's context is injected into the outgoing message
        properties. The request Message object itself is reused as the reply,
        which retains its correlation_id; a freshly created Message would need
        the correlation_id copied from the request. The original request is
        acknowledged after the reply has been sent.
        """
        worker = self.getName()
        print(worker + " working")
        print(message)

        broker = "amqp://localhost:5672"
        request_tags = {
            "component": "workers",
            "message_bus.destination": worker,
            "span.kind": "consumer",
            "peer.address": broker,
        }
        parent_context = span_context("text_map", message.properties,
                                      self.logger)
        with opentracing.tracer.start_active_span(
                operation_name=worker,
                child_of=parent_context,
                tags=request_tags):
            # Simple reply. In a real processor **** DO WORK HERE ****
            reply = {"reply": worker + " reply"}

            # Trace the rpcmessage response; see
            # https://opentracing.io/guides/python/tracers/ and, for the
            # standard tags, https://opentracing.io/specification/conventions/
            reply_tags = {
                "component": "workers",
                "message_bus.destination": message.reply_to,
                "span.kind": "producer",
                "peer.address": broker,
            }
            with opentracing.tracer.start_active_span(
                    operation_name=worker,
                    child_of=opentracing.tracer.active_span,
                    tags=reply_tags) as reply_scope:
                message.properties = inject_span("text_map", reply_scope.span,
                                                 self.logger)
                # Route the reused message back to the requester.
                message.subject = message.reply_to
                message.reply_to = None
                message.body = json.dumps(reply)
                self.producer.send(message)
                message.acknowledge()  # Acknowledges the original request
            def aws_api_StartExecution():
                """
                Handle the Step Functions StartExecution REST API action:
                https://docs.aws.amazon.com/step-functions/latest/apireference/API_StartExecution.html

                Validates the request ``params`` (stateMachineArn, optional
                name, optional input JSON string), looks up the state machine,
                publishes a start_execution event to the event dispatcher and
                returns a ``(body, status)`` tuple: ``jsonify`` of
                ``{"executionArn", "startDate"}`` with 200 on success, or an
                ``aws_error`` body with 400 for invalid requests and 500 if the
                start event could not be published.
                """
                state_machine_arn = params.get("stateMachineArn")
                if not state_machine_arn:
                    self.logger.warning(
                        "RestAPI StartExecution: stateMachineArn must be specified"
                    )
                    return aws_error("MissingRequiredParameter"), 400

                if not valid_state_machine_arn(state_machine_arn):
                    self.logger.warning(
                        "RestAPI StartExecution: {} is an invalid State Machine ARN"
                        .format(state_machine_arn))
                    return aws_error("InvalidArn"), 400

                # If name isn't provided create one from a UUID.
                # TODO names should be unique within a 90 day period; there is
                # currently no check for uniqueness of provided names, so client
                # code that doesn't honour this may succeed here but fail when
                # calling real AWS Step Functions.
                name = params.get("name", str(uuid.uuid4()))
                if not valid_name(name):
                    self.logger.warning(
                        "RestAPI StartExecution: {} is an invalid name".format(
                            name))
                    return aws_error("InvalidName"), 400

                execution_input = params.get("input", "{}")
                # The input must be a JSON string. Anything else (e.g. a JSON
                # null or number in the request) would make len() below raise
                # TypeError and escape as an unhandled server error.
                if not isinstance(execution_input, (str, bytes, bytearray)):
                    self.logger.error(
                        "RestAPI StartExecution: Invalid input, expected a JSON "
                        "string but got {}".format(type(execution_input).__name__))
                    return aws_error("InvalidExecutionInput"), 400

                # Check the input against the 262144 size quota described in the
                # Step Functions quotas page before parsing it:
                # https://docs.aws.amazon.com/step-functions/latest/dg/limits.html
                if len(execution_input) > MAX_DATA_LENGTH:
                    self.logger.error(
                        "RestAPI StartExecution: input size for execution '{}' exceeds "
                        "the maximum number of characters service limit.".
                        format(name))
                    return aws_error("InvalidExecutionInput"), 400

                try:
                    execution_input = json.loads(execution_input)
                except ValueError:  # includes JSONDecodeError/UnicodeDecodeError
                    self.logger.error(
                        "RestAPI StartExecution: input {} does not contain valid JSON"
                        .format(execution_input))
                    return aws_error("InvalidExecutionInput"), 400

                # Look up stateMachineArn
                state_machine = self.asl_store.get_cached_view(
                    state_machine_arn)
                if not state_machine:
                    self.logger.info(
                        "RestAPI StartExecution: State Machine {} does not exist"
                        .format(state_machine_arn))
                    return aws_error("StateMachineDoesNotExist"), 400

                # Form executionArn from stateMachineArn and name
                arn = parse_arn(state_machine_arn)
                execution_arn = create_arn(
                    service="states",
                    region=arn.get("region", self.region),
                    account=arn["account"],
                    resource_type="execution",
                    resource=arn["resource"] + ":" + name,
                )

                with opentracing.tracer.start_active_span(
                        operation_name="StartExecution:ExecutionLaunching",
                        child_of=span_context("http_headers", request.headers,
                                              self.logger),
                        tags={
                            "component": "rest_api",
                            "execution_arn": execution_arn
                        }) as scope:
                    # The application context object is described in:
                    # https://docs.aws.amazon.com/step-functions/latest/dg/input-output-contextobject.html
                    # RFC 3339 timestamp, see
                    # https://stackoverflow.com/questions/8556398/generate-rfc-3339-timestamp-in-python
                    start_time = datetime.now(
                        timezone.utc).astimezone().isoformat()
                    context = {
                        "Tracer": inject_span("text_map", scope.span,
                                              self.logger),
                        "Execution": {
                            "Id": execution_arn,
                            "Input": execution_input,
                            "Name": name,
                            "RoleArn": state_machine.get("roleArn"),
                            "StartTime": start_time,
                        },
                        "State": {
                            "EnteredTime": start_time,
                            "Name": ""
                        },  # Start state
                        "StateMachine": {
                            "Id": state_machine_arn,
                            "Name": state_machine.get("name"),
                        },
                    }

                    event = {"data": execution_input, "context": context}
                    # threadsafe=True is important here as the RestAPI runs in
                    # a different thread to the main event_dispatcher loop.
                    try:
                        self.event_dispatcher.publish(event,
                                                      threadsafe=True,
                                                      start_execution=True)
                    except Exception:
                        # Narrower than a bare except so KeyboardInterrupt and
                        # SystemExit still propagate; keep the traceback.
                        error_message = (
                            "RestAPI StartExecution: Internal messaging "
                            "error, start message could not be published.")
                        self.logger.exception(error_message)
                        return aws_error("InternalError", error_message), 500

                    resp = {
                        "executionArn": execution_arn,
                        "startDate": time.time()
                    }

                    return jsonify(resp), 200
Example #4
0
        def asl_service_states_startExecution():
            """
            Service Integration to stepfunctions: launch a child state machine
            execution on this local ASL Workflow Engine. Integrating with
            *real* AWS stepfunctions should be possible in due course.

            Until September 2019 AWS had no direct Service Integration to
            other stepfunctions and a lambda was needed as a proxy to the
            child, so many online examples still illustrate that approach.
            The newer, more direct integration is documented here:
            https://docs.aws.amazon.com/step-functions/latest/dg/connect-stepfunctions.html

            The resource ARN should be of the form:
            arn:aws:states:region:account-id:states:startExecution

            The Task must have a Parameters field of the form:

            "Parameters": {
                "Input": "ChildStepFunctionInput",
                "StateMachineArn": "ChildStateMachineArn",
                "Name": "OptionalExecutionName"
            },

            If the execution Name is not specified a UUID is used.

            On success callback receives
            {"executionArn": ..., "startDate": time.time()}. A missing
            StateMachineArn or an unknown state machine is reported through
            error_callback instead.

            TODO: only the request/response integration is supported, NOT the
            "run a job" (.sync) or "wait for callback" (.waitForTaskToken)
            forms, so we can't yet wait for the child to complete or for a
            callback. .sync would *probably* need to poll ListExecutions with
            statusFilter="SUCCEEDED" and compare against our executionArn,
            which is more involved in a clustered environment. The callback
            form shouldn't need polling, but there seems to be no stepfunctions
            integration for SendTaskSuccess (or SendTaskFailure or
            SendTaskHeartbeat), so a child stepfunction has no direct way to do
            the callback other than proxying via a lambda. This looks like an
            accidental omission that may be added in due course.
            """
            def report_error(error_type, error_message):
                # Return a start-up failure to the calling Task.
                error_callback(
                    context.get("Tracer", {}),
                    {"errorType": error_type, "errorMessage": error_message},
                    callback,
                )

            def count_integration(metric_name):
                # Service Integration metrics are dimensioned by the
                # integration's resource ARN (the same for every child machine).
                if self.task_metrics:
                    self.task_metrics[metric_name].inc(
                        {"ServiceIntegrationResourceArn": resource_arn}
                    )

            execution_name = parameters.get("Name", str(uuid.uuid4()))

            state_machine_arn = parameters.get("StateMachineArn")
            if not state_machine_arn:
                report_error(
                    "MissingRequiredParameter",
                    "TaskDispatcher asl_service_states_startExecution: "
                    "StateMachineArn must be specified",
                )
                return

            arn_parts = parse_arn(state_machine_arn)
            state_machine_name = arn_parts["resource"]
            execution_arn = create_arn(
                service="states",
                region=arn_parts.get("region", "local"),
                account=arn_parts["account"],
                resource_type="execution",
                resource=":".join((state_machine_name, execution_name)),
            )

            # Look up stateMachineArn
            state_machine = self.state_engine.asl_store.get_cached_view(
                state_machine_arn
            )
            if not state_machine:
                report_error(
                    "StateMachineDoesNotExist",
                    ("TaskDispatcher asl_service_states_startExecution: "
                     "State Machine {} does not exist").format(state_machine_arn),
                )
                return

            # Trace the child StartExecution request; see
            # https://opentracing.io/guides/python/tracers/ and, for the
            # standard tags, https://opentracing.io/specification/conventions/
            parent_context = span_context(
                "text_map", context.get("Tracer", {}), self.logger
            )
            span_tags = {
                "component": "task_dispatcher",
                "resource_arn": resource_arn,
                "execution_arn": context["Execution"]["Id"],
                "child_execution_arn": execution_arn,
            }
            with opentracing.tracer.start_active_span(
                operation_name="Task", child_of=parent_context, tags=span_tags
            ) as scope:
                # Build the child's context object and the event that launches
                # the requested execution. RFC 3339 start time, see
                # https://stackoverflow.com/questions/8556398/generate-rfc-3339-timestamp-in-python
                started_at = datetime.now(timezone.utc).astimezone().isoformat()
                child_context = {
                    "Tracer": inject_span("text_map", scope.span, self.logger),
                    "Execution": {
                        "Id": execution_arn,
                        "Input": parameters.get("Input", {}),
                        "Name": execution_name,
                        "RoleArn": state_machine.get("roleArn"),
                        "StartTime": started_at,
                    },
                    "State": {"EnteredTime": started_at, "Name": ""},  # Start state
                    "StateMachine": {
                        "Id": state_machine_arn,
                        "Name": state_machine_name,
                    },
                }
                self.state_engine.event_dispatcher.publish(
                    {"data": parameters.get("Input", {}), "context": child_context},
                    start_execution=True,
                )

                # Service Integration metrics, per
                # https://docs.aws.amazon.com/step-functions/latest/dg/procedure-cw-metrics.html
                count_integration("ServiceIntegrationsScheduled")

                result = {"executionArn": execution_arn, "startDate": time.time()}

                # Launching is "fire and forget", so the integration counts as
                # successful once the request has been dispatched. TODO the
                # other ServiceIntegration metrics only matter for synchronous
                # (.sync) invocations.
                count_integration("ServiceIntegrationsSucceeded")

                callback(result)
Example #5
0
        def asl_service_rpcmessage():
            """
            Publish a request message to the required rpcmessage worker resource.

            The message body is a JSON string containing the "effective
            parameters" passed from the StateEngine, and the message subject
            is the resource name, which maps to the name of the queue that the
            resource listens on. So that responses return to the correct
            calling Task, the message carries this workflow engine's reply_to
            queue and a correlation ID. Responses are matched by correlation
            ID because they may arrive in a different order from the requests.

            A timeout is also scheduled (``timeout``, also used as the message
            TTL). If the request is still pending when it fires, the request
            is removed, timeout metrics are counted and a States.Timeout error
            is returned to the calling Task.

            Setting content_type to application/json isn't necessary for
            correct operation, however it is the correct thing to do:
            https://www.ietf.org/rfc/rfc4627.txt.
            """
            # TODO deal with the case of delivering to a different broker.

            # Associate response callback with this request via correlation ID
            correlation_id = str(uuid.uuid4())

            def on_timeout():
                """Fail the request with States.Timeout if it is still pending."""
                # pop() rather than get(): the timeout may fire after a
                # successful response has already removed the request.
                request = self.pending_requests.pop(correlation_id, None)
                if request is None:
                    return
                (pending_callback, pending_arn, _sched_time, _timeout_id,
                 task_span) = request
                if self.task_metrics:
                    # Count an engine-side timeout the same way
                    # handle_rpcmessage_response counts a processor-reported
                    # "TimeoutError": as both timed out and failed.
                    self.task_metrics["LambdaFunctionsTimedOut"].inc(
                        {"LambdaFunctionArn": pending_arn}
                    )
                    self.task_metrics["LambdaFunctionsFailed"].inc(
                        {"LambdaFunctionArn": pending_arn}
                    )
                error = {
                    "errorType": "States.Timeout",
                    "errorMessage": "State or Execution ran for longer " +
                        "than the specified TimeoutSeconds value",
                }
                # Re-activate the request span (finished when this scope
                # closes) so the error is recorded against it.
                with opentracing.tracer.scope_manager.activate(
                    span=task_span,
                    finish_on_close=True
                ):
                    error_callback(None, error, pending_callback)

            # Start an OpenTracing trace for the rpcmessage request; see
            # https://opentracing.io/guides/python/tracers/ and, for the
            # standard tags, https://opentracing.io/specification/conventions/
            # finish_on_close=False: the span outlives this block and is
            # finished when the response (or the timeout) is handled.
            with opentracing.tracer.start_active_span(
                operation_name="Task",
                child_of=span_context("text_map", context.get("Tracer", {}), self.logger),
                tags={
                    "component": "task_dispatcher",
                    "resource_arn": resource_arn,
                    "message_bus.destination": resource,
                    "span.kind": "producer",
                    "peer.address": self.peer_address,
                    "execution_arn": context["Execution"]["Id"]
                },
                finish_on_close=False
            ) as scope:
                carrier = inject_span("text_map", scope.span, self.logger)
                message = Message(
                    json.dumps(parameters),
                    properties=carrier,
                    content_type="application/json",
                    subject=resource,
                    reply_to=self.reply_to.name,
                    correlation_id=correlation_id,
                    # Give the RPC Message a TTL equivalent to the ASL
                    # Task State (or Execution) timeout period. Both are ms.
                    expiration=timeout,
                )

                timeout_id = self.state_engine.event_dispatcher.set_timeout(
                    on_timeout, timeout
                )

                # The response is handled by handle_rpcmessage_response(). If it
                # arrives before the timeout expires the timeout is cancelled,
                # so store the timeout_id alongside the callback, keyed by
                # correlation_id. The span is stored too so the response or a
                # timeout can be recorded on the request span.
                self.pending_requests[correlation_id] = (
                    callback, resource_arn, time.time() * 1000, timeout_id, scope.span
                )
                self.producer.send(message)

                if self.task_metrics:
                    self.task_metrics["LambdaFunctionsScheduled"].inc(
                        {"LambdaFunctionArn": resource_arn}
                    )
Example #6
0
    def handle_rpcmessage_response(self, message):
        """
        Message listener for messages on this workflow engine's reply_to queue.

        The response is matched to its pending request by
        message.correlation_id. For a matching request the request's timeout
        is cancelled, the stored request span is re-activated (and finished
        when its scope closes), the body is parsed and the stored callback is
        invoked with the result. Task metrics are updated when
        self.task_metrics is set.

        Oversized bodies become a States.DataLimitExceeded result. Bodies that
        are not valid UTF-8 JSON become a States.TaskFailed result, so the
        Task is still completed; its timeout has already been cleared at that
        point. A response with no matching request (e.g. one arriving after
        the request's timeout removed it) is recorded as an error on a new
        span. The message is acknowledged unless an exception escapes.
        """
        correlation_id = message.correlation_id
        request = self.pending_requests.pop(correlation_id, None)
        if request:
            callback, resource_arn, sched_time, timeout_id, rpcmessage_task_span = request
            with opentracing.tracer.scope_manager.activate(
                span=rpcmessage_task_span,
                finish_on_close=True
            ) as scope:
                # Cancel the timeout previously set for this request.
                self.state_engine.event_dispatcher.clear_timeout(timeout_id)
                if callable(callback):
                    message_body = message.body
                    # Check the 262144 data quota from the Step Functions
                    # quotas page on the raw body, before decoding/parsing:
                    # https://docs.aws.amazon.com/step-functions/latest/dg/limits.html
                    if len(message_body) > MAX_DATA_LENGTH:
                        result = {"errorType": "States.DataLimitExceeded"}
                    else:
                        try:
                            result = json.loads(message_body.decode("utf8"))
                        except ValueError:  # includes UnicodeDecodeError
                            self.logger.error(
                                "Response {} does not contain valid JSON".format(message.body)
                            )
                            # Report the failure rather than leaving the Task
                            # waiting forever: its timeout is already cleared.
                            result = {
                                "errorType": "States.TaskFailed",
                                "errorMessage": "Response does not contain valid JSON",
                            }

                    # Valid JSON need not be an object (e.g. a list result);
                    # only a dict can carry an errorType.
                    error_type = (
                        result.get("errorType") if isinstance(result, dict) else None
                    )
                    if error_type:
                        scope.span.set_tag("error", True)
                        scope.span.log_kv(
                            {
                                "event": error_type,
                                "message": result.get("errorMessage", ""),
                            }
                        )
                        if self.task_metrics:
                            # When Lambda times out it returns JSON including
                            # "errorType": "TimeoutError", e.g.
                            # https://stackoverflow.com/questions/65036533/my-lambda-is-throwing-a-invoke-error-timeout
                            # rpcmessage processors should follow the same
                            # convention so processor timeouts can be counted.
                            if error_type == "TimeoutError":
                                self.task_metrics["LambdaFunctionsTimedOut"].inc(
                                    {"LambdaFunctionArn": resource_arn}
                                )

                            self.task_metrics["LambdaFunctionsFailed"].inc(
                                {"LambdaFunctionArn": resource_arn}
                            )
                    elif self.task_metrics:
                        self.task_metrics["LambdaFunctionsSucceeded"].inc(
                            {"LambdaFunctionArn": resource_arn}
                        )

                    if self.task_metrics:
                        # sched_time was stored in ms when the request was sent.
                        duration = (time.time() * 1000.0) - sched_time
                        self.task_metrics["LambdaFunctionTime"].observe(
                            {"LambdaFunctionArn": resource_arn}, duration
                        )

                    callback(result)
        else:
            with opentracing.tracer.start_active_span(
                operation_name="Task",
                child_of=span_context("text_map", message.properties, self.logger),
                tags={
                    "component": "task_dispatcher",
                    "message_bus.destination": self.reply_to.name,
                    "span.kind": "consumer",
                    "peer.address": self.peer_address
                }
            ) as scope:
                self.logger.info("Response {} has no matching requestor".format(message))
                scope.span.set_tag("error", True)
                scope.span.log_kv(
                    {
                        "event": "No matching requestor",
                        "message": "Response has no matching requestor",
                    }
                )

        message.acknowledge(multiple=False)