def merge_args_and_kwargs(task_result: Message,
                          task_args_and_kwargs: ArgsAndKwargs) -> Any:
    """
    Merges args & kwargs passed explicitly to a task entry in a pipeline with results from the previous task.

    If there are no args & kwargs to merge then the result of the previous task is returned unchanged
    (packed into an Any).

    :param task_result: the result of the previous task - either a single proto or a TupleOfAny
    :param task_args_and_kwargs: the args and kwargs from the task entry to merge
    :return: packed Any holding either the original result or the merged ArgsAndKwargs
    """
    args_to_merge = task_args_and_kwargs.args
    kwargs = task_args_and_kwargs.kwargs

    # Test the repeated fields for emptiness directly.  any(...) is wrong here because it
    # also returns False when the elements are all *falsy* (e.g. empty protos), which would
    # silently discard explicitly supplied empty arguments instead of merging them.
    if not args_to_merge.items and not kwargs.items:
        return pack_any(task_result)

    merged_args = TupleOfAny()
    # task result may be a single proto in which case we have to wrap into TupleOfAny to be able to extend
    if not isinstance(task_result, TupleOfAny):
        merged_args.items.append(pack_any(task_result))
    else:
        merged_args.items.extend(task_result.items)

    merged_args.items.extend(args_to_merge.items)

    return pack_any(ArgsAndKwargs(args=merged_args, kwargs=kwargs))
 def serialise_result(self, task_result: TaskResult, result, state):
     """
     Serialises the result of a task invocation into a TaskResult

     :param task_result: the TaskResult proto to populate
     :param result: the task result value
     :param state: the task state value
     """
     # convert each value to proto, pack into an Any and copy onto the corresponding field
     for destination, value in ((task_result.result, result), (task_result.state, state)):
         destination.CopyFrom(pack_any(_convert_to_proto(value, self._protobuf_converters)))
    def serialise_args_and_kwargs(self, args, kwargs) -> Any:
        """
        Serialises Python args and kwargs into protobuf

        If there is a single arg and no kwargs and the arg is already a protobuf it returns
        that instance packed inside an Any.  Otherwise it returns an ArgsAndKwargs packed in an Any.

        :param args: the Python args
        :param kwargs: the Python kwargs
        :return: protobuf Any
        """

        # if kwargs are empty and this is a single protobuf arguments then
        # send in simple format i.e. fn(protobuf) -> protobuf as opposed to fn(*args, **kwargs) -> (*results,)
        # so as to aid calling flink functions written in other frameworks that might not understand
        # how to deal with the concept of keyword arguments or tuples of arguments
        if isinstance(args, Message) and not kwargs:
            return pack_any(args)

        if not _is_tuple(args):
            args = (args, )

        request = ArgsAndKwargs()
        request.args.CopyFrom(_convert_to_proto(args, self._protobuf_converters))
        request.kwargs.CopyFrom(_convert_to_proto(kwargs, self._protobuf_converters))

        return pack_any(request)
    async def aggregate(self, context: TaskContext, group: Group):
        """
        Aggregates the results of all pipelines in a (possibly nested) group into a single
        TaskResult, or into a TaskException if any task in the group errored.

        :param context: the task context holding the pipeline state
        :param group: the group to aggregate
        :return: a TaskResult or TaskException identified by this group's group_id
        """
        task_results = context.pipeline_state.task_results  # type: TaskResults

        aggregated_results, aggregated_states, aggregated_errors = [], [], []

        # FIFO traversal of nested groups.  NB: use a distinct loop variable so the 'group'
        # parameter is not clobbered - the previous code reassigned 'group' inside the loop,
        # which meant the storage cleanup and the aggregated result id below referenced the
        # *last nested* group instead of the group actually being aggregated.
        stack = deque([(group, aggregated_results)])  # errors are flat not nested so don't get stacked

        while len(stack) > 0:
            current_group, results = stack.popleft()

            for pipeline in current_group:
                last_entry = pipeline[-1]

                if isinstance(last_entry, Group):
                    stack_results = []
                    results.append(stack_results)
                    stack.append((last_entry, stack_results))
                else:
                    proto = task_results.by_id[last_entry.task_id]  # Any
                    result, state, error = await self._load_result(current_group, proto)

                    # We don't need the individual task results anymore so remove to save space / reduce message size
                    del task_results.by_id[last_entry.task_id]

                    results.append(result)

                    aggregated_states.append(state)  # states are flat not nested so don't get stacked
                    aggregated_errors.append(error)  # errors are flat not nested so don't get stacked

        # cleanup storage for this group
        await self._cleanup_storage(group)

        aggregated_errors = [error for error in aggregated_errors if error is not None]
        aggregated_state = self._aggregate_state(aggregated_states)

        if any(aggregated_errors):

            serialised_state = self._serialiser.to_proto(aggregated_state)

            task_exception = TaskException(
                id=group.group_id,
                type='__aggregate.error',  # plain literal - no placeholders so no f-string needed
                exception_type='statefun_tasks.AggregatedError',
                exception_message='|'.join([f'{e.id}, {e.type}, {e.exception_message}' for e in aggregated_errors]),
                stacktrace='|'.join([f'{e.id}, {e.stacktrace}' for e in aggregated_errors]),
                state=pack_any(serialised_state))

            return task_exception

        else:

            task_result = TaskResult(id=group.group_id)
            self._serialiser.serialise_result(task_result, aggregated_results, aggregated_state)

            return task_result
Example #5
0
def _create_task_result(task_input, result=None):
    """
    Creates the appropriate result proto for the given task input.

    A TaskActionRequest produces a TaskActionResult; any other input produces a TaskResult
    whose type name is derived from the input's type.  An optional result proto is packed
    into the created proto when supplied.

    :param task_input: the incoming request proto
    :param optional result: a result proto to pack into the created result
    :return: a TaskActionResult or TaskResult
    """
    task_result = (
        TaskActionResult(id=task_input.id, action=task_input.action)
        if isinstance(task_input, TaskActionRequest)
        else TaskResult(id=task_input.id, type=f'{task_input.type}.result'))

    if result is not None:
        task_result.result.CopyFrom(pack_any(result))

    return task_result
    def send_egress_message(self, topic, value):
        """
        Sends a message to an egress topic

        :param topic: the topic name
        :param value: the message to send
        """
        # pack into an Any and serialise to bytes before handing off to the kafka egress
        payload = pack_any(value).SerializeToString()

        self._context.send_egress(
            kafka_egress_message(
                typename=self._egress_type_name,
                topic=topic,
                value=payload))
    async def _save_result(self, group, task_id, proto, task_results, size_of_state):
        """
        Records a task result either in external storage (when the pipeline state has grown
        past the storage threshold) or inline in the task results state.

        :param group: the group this task belongs to
        :param task_id: the id of the task
        :param proto: the packed result proto to record
        :param task_results: the TaskResults state to record into
        :param size_of_state: the current size of the pipeline state
        """
        # short-circuits: only attempt external storage when configured and over threshold
        saved_to_storage = (
            self._storage is not None
            and size_of_state >= self._storage.threshold
            and await self._try_save_to_store([group.group_id, task_id], proto))

        if saved_to_storage:
            # record ptr to state
            task_results.by_id[task_id].CopyFrom(pack_any(StringValue(value=task_id)))
        else:
            # record data to state
            task_results.by_id[task_id].CopyFrom(proto)
    async def add_result(self, context: TaskContext, group: Group, task_id, task_result_or_exception):
        """
        Records the result (or exception) of a task in a group so it can be aggregated later.

        :param context: the task context holding the pipeline state
        :param group: the group the task belongs to
        :param task_id: the id of the task that completed
        :param task_result_or_exception: the TaskResult or TaskException produced by the task
        """
        task_results = context.pipeline_state.task_results  # type: TaskResults

        packed = pack_any(task_result_or_exception)
        last_task = self._graph.get_last_task_in_chain(task_id)

        if task_id == last_task.task_id:
            # this is the last task in its chain within the group - record the result for aggregation
            await self._save_result(group, task_id, packed, task_results, context.pipeline_state_size)
        elif isinstance(task_result_or_exception, TaskException):
            # additionally propagate the error onto the last stage of this chain
            await self._save_result(group, last_task.task_id, packed, task_results, context.pipeline_state_size)
    def serialise_request(self,
                          task_request: TaskRequest,
                          request: Any,
                          state=None,
                          retry_policy=None):
        """
        Serialises args, kwargs and optional state into a TaskRequest

        :param task_request: the TaskRequest
        :param request: request (proto format)
        :param optional state: task state
        :param optional retry_policy: task retry policy
        """
        task_request.request.CopyFrom(request)

        serialised_state = pack_any(_convert_to_proto(state, self._protobuf_converters))
        task_request.state.CopyFrom(serialised_state)

        # truthiness test (not 'is not None') deliberately skips empty retry policy protos
        if retry_policy:
            task_request.retry_policy.CopyFrom(retry_policy)
 def set_state(self, obj):
     """
     Serialises the given object and stores it as the internal task state.

     :param obj: the Python object to store as state
     """
     serialised = self._serialiser.to_proto(obj)
     self.task_state.internal_state.CopyFrom(pack_any(serialised))