def _run_flink_loop(self, message_arg: Union[TaskRequest, TaskResult, TaskException, TaskActionRequest], target: Address, caller=None):
    """Deliver one message to ``target`` and follow the chain of messages it triggers.

    A single-invocation ``ToFunction`` request addressed to ``target`` is built,
    ``message_arg`` is packed as its argument, and the state tracked for that
    address is copied in. The request is run through ``async_handler`` to
    completion on the current event loop, and the response is processed.

    :param message_arg: the message to deliver.
    :param target: address of the receiving function.
    :param caller: optional address recorded as the invocation's caller.
    :return: the response's egress message if it has one. Outgoing messages in
        that same response are then not delivered. Otherwise, each outgoing
        message is unpacked and delivered recursively with ``target`` as the
        caller, and the first truthy egress value found that way is returned.
        Returns ``None`` when no egress is produced.
    """
    request = ToFunction()
    update_address(request.invocation.target, target.namespace, target.type, target.id)

    invocation = request.invocation.invocations.add()
    if caller:
        update_address(invocation.caller, caller.namespace, caller.type, caller.id)
    invocation.argument.Pack(message_arg)

    self._copy_state_to_invocation(target.namespace, target.type, target.id, request)

    loop = asyncio.get_event_loop()
    response_bytes = loop.run_until_complete(async_handler(request.SerializeToString()))
    result = self._process_result(request, response_bytes)

    if result.egress_message is not None:
        return result.egress_message

    for outgoing in result.outgoing_messages:
        # NOTE(review): TaskActionRequest is an accepted input type above but is
        # not in this unpack list -- confirm functions never send one onward.
        next_arg = unpack_any(outgoing.argument, [TaskRequest, TaskResult, TaskException])
        egress_value = self._run_flink_loop(message_arg=next_arg, target=outgoing.target, caller=target)
        if egress_value:
            return egress_value
def setup(self, request_bytes):
    """Parse a serialized ``ToFunction`` request and prepare the batch to run.

    The target function is looked up via ``self.functions.for_type``. Its
    registered state specs are then split into those whose values were
    provided in the request and those that are missing.

    :param request_bytes: the serialized ``ToFunction`` message.
    :raises ValueError: if no function is registered for the target's
        namespace and type.

    On success this sets ``self.batch`` (the request's invocations),
    ``self.context`` (a ``BatchContext`` over the resolved state values) and
    ``self.target_function``. ``self.missing_state_specs`` is assigned only
    when at least one registered spec had no value in the request.
    """
    to_function = ToFunction()
    to_function.ParseFromString(request_bytes)

    target_address = to_function.invocation.target
    target_function = self.functions.for_type(target_address.namespace, target_address.type)
    if target_function is None:
        # Name the requested type. target_function is always None on this path,
        # so putting it in the message (as before) hid what was actually missing.
        raise ValueError("Unable to find a function of type {}/{}".format(
            target_address.namespace, target_address.type))

    # Partition the registered state specs: a provided value is resolved into
    # the batch context, and an absent one is recorded as missing.
    provided_state_values = self.provided_state_values(to_function)
    resolved_state_values = {}
    missing_state_specs = []
    for state_name, state_spec in target_function.registered_state_specs.items():
        if state_name in provided_state_values:
            resolved_state_values[state_name] = provided_state_values[state_name]
        else:
            missing_state_specs.append(state_spec)

    self.batch = to_function.invocation.invocations
    self.context = BatchContext(target_address, resolved_state_values)
    self.target_function = target_function
    # NOTE(review): when nothing is missing this attribute is left untouched, so
    # a value from an earlier setup() on the same object would survive -- confirm
    # the handler is not reused across requests, or that callers reset it.
    if missing_state_specs:
        self.missing_state_specs = missing_state_specs
def setup(self, request_bytes):
    """Parse a serialized ``ToFunction`` request and prepare the batch to run.

    A ``BatchContext`` is built from the request's target address and state. The
    target function is then resolved via ``self.functions.for_type``.

    :param request_bytes: the serialized ``ToFunction`` message.
    :raises ValueError: if no function is registered for the target's
        namespace and type.

    On success this sets ``self.batch`` (the request's invocations),
    ``self.context`` and ``self.target_function``.
    """
    to_function = ToFunction()
    to_function.ParseFromString(request_bytes)

    context = BatchContext(to_function.invocation.target, to_function.invocation.state)
    address = context.address
    target_function = self.functions.for_type(address.namespace, address.type)
    if target_function is None:
        # Name the requested type. target_function is always None on this path,
        # so putting it in the message (as before) hid what was actually missing.
        raise ValueError("Unable to find a function of type {}/{}".format(
            address.namespace, address.type))

    self.batch = to_function.invocation.invocations
    self.context = context
    self.target_function = target_function
class InvocationBuilder(object):
    """Fluent builder for a ``ToFunction`` request message.

    Each ``with_*`` method mutates the wrapped ``self.to_function`` and returns
    ``self``, so calls can be chained. ``SerializeToString`` produces the
    serialized request.
    """

    def __init__(self):
        self.to_function = ToFunction()

    def with_target(self, ns, type, id):
        """Set the address of the function being invoked."""
        target = self.to_function.invocation.target
        InvocationBuilder.set_address(ns, type, id, target)
        return self

    def with_state(self, name, value=None):
        """Add a state entry named ``name``.

        When ``value`` is truthy, it is packed into an ``Any`` and the
        serialized ``Any`` bytes are stored as the state value. Otherwise the
        entry has a name only.
        """
        entry = self.to_function.invocation.state.add()
        entry.state_name = name
        if value:
            packed = Any()
            packed.Pack(value)
            entry.state_value = packed.SerializeToString()
        return self

    def with_invocation(self, arg, caller=None):
        """Append an invocation whose argument is ``arg`` packed into an ``Any``.

        ``caller``, when given, is a ``(namespace, type, id)`` triple.
        """
        invocation = self.to_function.invocation.invocations.add()
        if caller:
            caller_ns, caller_type, caller_id = caller
            InvocationBuilder.set_address(caller_ns, caller_type, caller_id, invocation.caller)
        invocation.argument.Pack(arg)
        return self

    def SerializeToString(self):
        """Return the serialized ``ToFunction`` message."""
        return self.to_function.SerializeToString()

    @staticmethod
    def set_address(namespace, type, id, address):
        """Copy namespace/type/id onto ``address`` (any object with those attributes)."""
        address.namespace, address.type, address.id = namespace, type, id
def _run_flink_loop(self, message_arg, target: Address, caller=None, egress_result=None):
    """Deliver one message to ``target`` and follow the chain of messages it triggers.

    ``message_arg`` is either a ``TypedValue`` (forwarded from another function)
    or a raw protobuf message (from ingress). A raw message is wrapped into a
    ``TypedValue`` using the typename from ``flink_value_type_for``. The
    request, with the state tracked for ``target`` copied in, is run through
    ``handler.handle_async`` to completion on the current event loop.

    :param egress_result: the first egress message seen so far in this chain.
        It is adopted from this response if still ``None``, and passed down to
        every recursive call.
    :return: the first truthy value returned by a recursive delivery of an
        outgoing message; otherwise ``egress_result``.
    """
    request = ToFunction()
    update_address(request.invocation.target, target.namespace, target.type, target.id)

    invocation = request.invocation.invocations.add()
    if caller:
        update_address(invocation.caller, caller.namespace, caller.type, caller.id)

    if isinstance(message_arg, TypedValue):
        # Already a TypedValue: one function is calling another.
        invocation.argument.CopyFrom(message_arg)
    else:
        # A raw protobuf message from ingress has to be wrapped in a TypedValue.
        flink_type = flink_value_type_for(message_arg)
        wrapped = TypedValue(typename=flink_type.typename,
                             has_value=True,
                             value=message_arg.SerializeToString())
        invocation.argument.CopyFrom(wrapped)

    self._copy_state_to_invocation(target.namespace, target.type, target.id, request)

    loop = asyncio.get_event_loop()
    response_bytes = loop.run_until_complete(handler.handle_async(request.SerializeToString()))
    result = self._process_result(request, response_bytes)

    # Only the first egress message in the chain is kept.
    if result.egress_message is not None and egress_result is None:
        egress_result = result.egress_message

    # NOTE(review): once egress_result is set, the first recursive call returns
    # it (it is truthy), so any remaining outgoing messages at this level are
    # never delivered -- confirm that is intended.
    for outgoing in result.outgoing_messages:
        egress_value = self._run_flink_loop(message_arg=outgoing.argument,
                                            target=outgoing.target,
                                            caller=target,
                                            egress_result=egress_result)
        if egress_value:
            return egress_value
    return egress_result
async def handle_async(
        self, request_bytes: typing.Union[str, bytes, bytearray]) -> bytes:
    """Run one serialized ``ToFunction`` batch and return the serialized reply.

    The target function is resolved by typename and its declared storage is
    resolved against the state carried in the request. If any state spec is
    missing, the serialized result of ``collect_failure`` is returned without
    invoking the function. Otherwise every invocation in the batch is
    delivered, in order, through one shared ``UserFacingContext``, and the
    serialized result of ``collect_success`` is returned.

    :raises ValueError: if no function is registered for the target typename.
    """
    pb_to_function = ToFunction()
    pb_to_function.ParseFromString(request_bytes)

    sdk_address = sdk_address_from_pb(pb_to_function.invocation.target)

    target_fn: StatefulFunction = self.functions.for_typename(sdk_address.typename)
    if not target_fn:
        raise ValueError(
            f"Unable to find a function of type {sdk_address.typename}")

    res = resolve(target_fn.storage_spec, sdk_address.typename, pb_to_function.invocation.state)
    if res.missing_specs:
        return collect_failure(res.missing_specs).SerializeToString()

    ctx = UserFacingContext(sdk_address, res.storage)
    fun = target_fn.fun
    is_async = target_fn.is_async
    for pb_invocation in pb_to_function.invocation.invocations:
        msg = Message(target_typename=sdk_address.typename,
                      target_id=sdk_address.id,
                      typed_value=pb_invocation.argument)
        # The context is shared across the batch; only the caller changes.
        ctx._caller = sdk_address_from_pb(pb_invocation.caller)
        if is_async:
            # Coroutine functions must be awaited to actually run.
            await fun(ctx, msg)
        else:
            fun(ctx, msg)

    return collect_success(ctx).SerializeToString()
class InvocationBuilder(object):
    """Fluent builder for a ``ToFunction`` request message using ``TypedValue`` payloads.

    Each ``with_*`` method mutates the wrapped ``self.to_function`` and returns
    ``self``, so calls can be chained. ``SerializeToString`` produces the
    serialized request.
    """

    def __init__(self):
        self.to_function = ToFunction()

    def with_target(self, ns, type, id):
        """Set the address of the function being invoked."""
        target = self.to_function.invocation.target
        InvocationBuilder.set_address(ns, type, id, target)
        return self

    def with_state(self, name, value=None):
        """Add a state entry named ``name``, with ``value`` stored as an Any-typed value if truthy."""
        entry = self.to_function.invocation.state.add()
        entry.state_name = name
        if value:
            entry.state_value.CopyFrom(self.to_typed_value_any_state(value))
        return self

    def with_invocation(self, arg, caller=None):
        """Append an invocation carrying ``arg`` as a ``TypedValue``.

        ``caller``, when given, is a ``(namespace, type, id)`` triple.
        """
        invocation = self.to_function.invocation.invocations.add()
        if caller:
            caller_ns, caller_type, caller_id = caller
            InvocationBuilder.set_address(caller_ns, caller_type, caller_id, invocation.caller)
        invocation.argument.CopyFrom(self.to_typed_value(arg))
        return self

    def SerializeToString(self):
        """Return the serialized ``ToFunction`` message."""
        return self.to_function.SerializeToString()

    @staticmethod
    def to_typed_value(proto_msg):
        """Wrap ``proto_msg`` as a TypedValue using the type URL and bytes of its packed ``Any``.

        NOTE(review): ``has_value`` is not set here, while other code builds
        TypedValue with ``has_value=True`` -- confirm against the proto version.
        """
        packed = Any()
        packed.Pack(proto_msg)
        return TypedValue(typename=packed.type_url, value=packed.value)

    @staticmethod
    def to_typed_value_any_state(proto_msg):
        """Wrap ``proto_msg`` as a TypedValue whose payload is the whole serialized ``Any``."""
        packed = Any()
        packed.Pack(proto_msg)
        return TypedValue(typename="type.googleapis.com/google.protobuf.Any",
                          value=packed.SerializeToString())

    @staticmethod
    def set_address(namespace, type, id, address):
        """Copy namespace/type/id onto ``address`` (any object with those attributes)."""
        address.namespace, address.type, address.id = namespace, type, id
def __init__(self):
    """Wrap a freshly constructed ``ToFunction`` message in ``self.to_function``."""
    self.to_function = ToFunction()
def __call__(self, request_bytes):
    """Handle one serialized ``ToFunction`` request and return the serialized reply.

    ``request_bytes`` is parsed into a ``ToFunction`` message and passed to
    ``self.handle_invocation``. Whatever that returns is serialized with
    ``SerializeToString``.
    """
    to_function = ToFunction()
    to_function.ParseFromString(request_bytes)
    return self.handle_invocation(to_function).SerializeToString()