def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                            beam_operation_cls, internal_operation_cls):
    """
    Instantiate ``beam_operation_cls`` for the given transform.

    A ``WorkerDoFn`` spec is built from ``udfs_proto`` together with the output tags
    of ``transform_proto`` and the coders ``factory`` reports for those tags.  A
    ``RemoteKeyedStateBackend`` is created and passed as an extra trailing argument
    in two cases:

    * the serialized function exposes a ``key_type`` attribute: the key coder is a
      ``FlattenRowCoder`` over the key's row schema fields, and the window coder is
      a ``TimeWindowCoder`` / ``CountWindowCoder`` when ``group_window`` is set
      (``None`` otherwise);
    * ``internal_operation_cls`` equals ``datastream_operations.StatefulOperation``:
      the key coder comes from ``key_type_info`` and no window coder is used.

    In every other case the operation is created without a state backend.

    :return: the object returned by ``beam_operation_cls(...)``.
    """
    tags = list(transform_proto.outputs.keys())
    coders_by_tag = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=tags,
        input=None,
        side_inputs=None,
        output_coders=[coders_by_tag[tag] for tag in tags])
    name = common.NameContext(transform_proto.unique_name)
    fn_proto = spec.serialized_fn

    if hasattr(fn_proto, "key_type"):
        # keyed operation, a KeyedStateBackend keyed by the flattened row is required.
        schema = fn_proto.key_type.row_schema
        key_coder = FlattenRowCoder([from_proto(field.type) for field in schema.fields])
        if not fn_proto.HasField('group_window'):
            namespace_coder = None
        elif fn_proto.group_window.is_time_window:
            namespace_coder = TimeWindowCoder()
        else:
            namespace_coder = CountWindowCoder()
    elif internal_operation_cls == datastream_operations.StatefulOperation:
        key_coder = from_type_info_proto(fn_proto.key_type_info)
        namespace_coder = None
    else:
        # Nothing keyed about this operation: no state backend is attached.
        return beam_operation_cls(
            name,
            spec,
            factory.counter_factory,
            factory.state_sampler,
            consumers,
            internal_operation_cls)

    state_backend = RemoteKeyedStateBackend(
        factory.state_handler,
        key_coder,
        namespace_coder,
        fn_proto.state_cache_size,
        fn_proto.map_state_read_cache_size,
        fn_proto.map_state_write_cache_size)
    return beam_operation_cls(
        name,
        spec,
        factory.counter_factory,
        factory.state_sampler,
        consumers,
        internal_operation_cls,
        state_backend)
def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                            beam_operation_cls, internal_operation_cls):
    """
    Instantiate ``beam_operation_cls`` for the given transform.

    A ``WorkerDoFn`` spec is built from ``udfs_proto`` plus the output tags of
    ``transform_proto`` and the matching coders obtained from ``factory``.  When the
    serialized function has a ``key_type`` attribute, a ``RemoteKeyedStateBackend``
    (key coder derived from ``key_type`` via ``from_proto``) is created and handed to
    the operation as an additional trailing argument.

    :return: the object returned by ``beam_operation_cls(...)``.
    """
    tags = list(transform_proto.outputs.keys())
    coders_by_tag = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=tags,
        input=None,
        side_inputs=None,
        output_coders=[coders_by_tag[tag] for tag in tags])

    if not hasattr(spec.serialized_fn, "key_type"):
        # Not a keyed operation: no state backend is needed.
        return beam_operation_cls(
            transform_proto.unique_name,
            spec,
            factory.counter_factory,
            factory.state_sampler,
            consumers,
            internal_operation_cls)

    # keyed operation, the KeyedStateBackend has to be created first.
    key_coder = from_proto(spec.serialized_fn.key_type)
    state_backend = RemoteKeyedStateBackend(
        factory.state_handler,
        key_coder,
        spec.serialized_fn.state_cache_size,
        spec.serialized_fn.map_state_read_cache_size,
        spec.serialized_fn.map_state_write_cache_size)
    return beam_operation_cls(
        transform_proto.unique_name,
        spec,
        factory.counter_factory,
        factory.state_sampler,
        consumers,
        internal_operation_cls,
        state_backend)
def _create_stateful_user_defined_function_operation(factory, transform_proto, consumers,
                                                     udfs_proto, beam_operation_cls,
                                                     internal_operation_cls):
    """
    Instantiate ``beam_operation_cls`` with a ``RemoteKeyedStateBackend`` attached.

    A ``WorkerDoFn`` spec is built from ``udfs_proto`` plus the output tags of
    ``transform_proto`` and the coders ``factory`` reports for them.  The key coder
    of the state backend is derived from the type of the first field of the
    serialized function's ``key_type_info``.

    :param factory: supplies ``get_output_coders``, ``state_handler``,
                    ``counter_factory`` and ``state_sampler``.
    :param transform_proto: transform whose ``outputs`` and ``unique_name`` are used.
    :param consumers: forwarded unchanged to ``beam_operation_cls``.
    :param udfs_proto: serialized function stored in the spec; must expose
                       ``key_type_info``.
    :param beam_operation_cls: callable that creates the operation.
    :param internal_operation_cls: forwarded unchanged to ``beam_operation_cls``.
    :return: the object returned by ``beam_operation_cls(...)``.
    """
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        # Look the coders up by output_tags (not by iterating the coder mapping) so
        # that the i-th coder always belongs to the i-th tag handed to the spec.
        output_coders=[output_coders[tag] for tag in output_tags])

    key_type_info = spec.serialized_fn.key_type_info
    key_row_coder = from_type_info_proto(key_type_info.field[0].type)
    # NOTE(review): the three literals presumably are the state cache size and the
    # map state read/write cache sizes (same argument positions as in
    # _create_user_defined_function_operation) - confirm and consider reading them
    # from the serialized function instead of hard-coding.
    keyed_state_backend = RemoteKeyedStateBackend(
        factory.state_handler,
        key_row_coder,
        1000,
        1000,
        1000)

    return beam_operation_cls(
        transform_proto.unique_name,
        spec,
        factory.counter_factory,
        factory.state_sampler,
        consumers,
        internal_operation_cls,
        keyed_state_backend)
def extract_stateful_function(user_defined_function_proto,
                              runtime_context: RuntimeContext,
                              keyed_state_backend: RemoteKeyedStateBackend):
    """
    Build the callables that drive a keyed stateful DataStream function.

    The pickled payload of ``user_defined_function_proto`` is interpreted according
    to ``function_type``:

    * ``KEYED_PROCESS`` / ``KEYED_CO_PROCESS``: the payload is the user's process
      function; elements and timers are dispatched to its ``process_element`` (or
      ``process_element1`` / ``process_element2``) and ``on_timer`` methods.
    * ``WINDOW``: the payload is a window operation descriptor from which a
      ``WindowOperator`` is created; the descriptor's window serializer also
      provides the namespace coder installed on ``keyed_state_backend``.

    :param user_defined_function_proto: proto exposing ``function_type`` and a
                                        pickled ``payload``.
    :param runtime_context: passed to the function's / window operator's ``open``.
    :param keyed_state_backend: state backend whose current key is switched before
                                every element and timer callback.
    :return: tuple ``(open_func, close_func, process_element_func,
             process_timer_func, internal_timer_service)``.
    :raises Exception: if ``function_type`` is not one of the supported types.
    """
    func_type = user_defined_function_proto.function_type
    # NOTE(review): pickle.loads executes arbitrary code contained in the payload.
    # Presumably the payload is produced by the job's own client - confirm that it
    # can never originate from an untrusted source.
    user_defined_func = pickle.loads(user_defined_function_proto.payload)
    internal_timer_service = InternalTimerServiceImpl(keyed_state_backend)

    def state_key_selector(normal_data):
        # The state backend receives the key wrapped in a Row.
        return Row(normal_data[0])

    def user_key_selector(normal_data):
        return normal_data[0]

    def input_selector(normal_data):
        return normal_data[1]

    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
    if func_type in (UserDefinedDataStreamFunction.KEYED_PROCESS,
                     UserDefinedDataStreamFunction.KEYED_CO_PROCESS):
        timer_service = TimerServiceImpl(internal_timer_service)
        ctx = InternalKeyedProcessFunctionContext(timer_service)
        on_timer_ctx = InternalKeyedProcessFunctionOnTimerContext(timer_service)
        process_function = user_defined_func
        internal_timer_service.set_namespace_serializer(VoidNamespaceSerializer())

        def open_func():
            # open/close are optional on the user function.
            if hasattr(process_function, "open"):
                process_function.open(runtime_context)

        def close_func():
            if hasattr(process_function, "close"):
                process_function.close()

        def on_event_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return _on_timer(TimeDomain.EVENT_TIME, timestamp, key)

        def on_processing_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return _on_timer(TimeDomain.PROCESSING_TIME, timestamp, key)

        def _on_timer(time_domain: TimeDomain, timestamp: int, key):
            user_current_key = user_key_selector(key)

            on_timer_ctx.set_timestamp(timestamp)
            on_timer_ctx.set_current_key(user_current_key)
            on_timer_ctx.set_time_domain(time_domain)

            return process_function.on_timer(timestamp, on_timer_ctx)

        if func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:

            def process_element(normal_data, timestamp: int):
                ctx.set_timestamp(timestamp)
                ctx.set_current_key(user_key_selector(normal_data))
                keyed_state_backend.set_current_key(state_key_selector(normal_data))
                return process_function.process_element(input_selector(normal_data), ctx)

        elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:

            def process_element(normal_data, timestamp: int):
                # normal_data[0] flags which input the element belongs to; the
                # element itself sits at index 1 (left) or index 2 (right).
                is_left = normal_data[0]
                if is_left:
                    user_input = normal_data[1]
                else:
                    user_input = normal_data[2]

                ctx.set_timestamp(timestamp)
                current_user_key = user_key_selector(user_input)
                # ctx is the context handed to process_element1/2 below, so the key
                # must be set on it (as the KEYED_PROCESS branch does); previously
                # only on_timer_ctx was updated here.
                ctx.set_current_key(current_user_key)
                on_timer_ctx.set_current_key(current_user_key)
                keyed_state_backend.set_current_key(state_key_selector(user_input))

                if is_left:
                    return process_function.process_element1(
                        input_selector(user_input), ctx)
                else:
                    return process_function.process_element2(
                        input_selector(user_input), ctx)

        else:
            raise Exception("Unsupported func_type: " + str(func_type))

    elif func_type == UserDefinedDataStreamFunction.WINDOW:
        window_operation_descriptor = user_defined_func
        window_assigner = window_operation_descriptor.assigner
        window_trigger = window_operation_descriptor.trigger
        allowed_lateness = window_operation_descriptor.allowed_lateness
        window_state_descriptor = window_operation_descriptor.window_state_descriptor
        internal_window_function = window_operation_descriptor.internal_window_function
        window_serializer = window_operation_descriptor.window_serializer
        # Install the window coder as the state backend's namespace coder.
        window_coder = window_serializer._get_coder()
        keyed_state_backend.namespace_coder = window_coder
        keyed_state_backend._namespace_coder_impl = window_coder.get_impl()
        window_operator = WindowOperator(
            window_assigner,
            keyed_state_backend,
            user_key_selector,
            window_state_descriptor,
            internal_window_function,
            window_trigger,
            allowed_lateness)
        internal_timer_service.set_namespace_serializer(window_serializer)

        def open_func():
            window_operator.open(runtime_context, internal_timer_service)

        def close_func():
            window_operator.close()

        def process_element(normal_data, timestamp: int):
            keyed_state_backend.set_current_key(state_key_selector(normal_data))
            return window_operator.process_element(input_selector(normal_data), timestamp)

        def on_event_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return window_operator.on_event_time(timestamp, key, namespace)

        def on_processing_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return window_operator.on_processing_time(timestamp, key, namespace)

    else:
        raise Exception("Unsupported function_type: " + str(func_type))

    input_handler = RunnerInputHandler(internal_timer_service, process_element)
    process_element_func = input_handler.process_element

    # The namespace coder impl is read only here, i.e. after the WINDOW branch had
    # the chance to replace it on the state backend.
    timer_handler = TimerHandler(
        internal_timer_service,
        on_event_time,
        on_processing_time,
        keyed_state_backend._namespace_coder_impl)
    process_timer_func = timer_handler.process_timer

    return (open_func, close_func, process_element_func, process_timer_func,
            internal_timer_service)
def extract_keyed_stateful_function(user_defined_function_proto,
                                    keyed_state_backend: RemoteKeyedStateBackend,
                                    runtime_context: RuntimeContext):
    """
    Build the callables that drive a keyed stateful DataStream function.

    The pickled payload of ``user_defined_function_proto`` is interpreted according
    to ``function_type``:

    * ``KEYED_PROCESS``: the payload is the user's process function; a
      ``OneInputRowWithTimerHandler`` dispatches elements to ``process_element``
      and timers to ``on_timer``.
    * ``KEYED_CO_PROCESS``: the payload is the user's process function, handled by
      a ``TwoInputRowWithTimerHandler``.
    * ``WINDOW``: the payload is a window operation descriptor from which a
      ``WindowOperator`` is created and wired into a
      ``OneInputRowWithTimerHandler``.

    :param user_defined_function_proto: proto exposing ``function_type`` and a
                                        pickled ``payload``.
    :param keyed_state_backend: state backend handed to the timer service, the
                                input handlers and the window operator.
    :param runtime_context: passed to the function's / window operator's ``open``.
    :return: tuple ``(process_element_func, open_func, close_func)``.
    :raises Exception: if ``function_type`` is not one of the supported types.
    """
    func_type = user_defined_function_proto.function_type
    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
    # NOTE(review): pickle.loads executes arbitrary code contained in the payload.
    # Presumably the payload is produced by the job's own client - confirm that it
    # can never originate from an untrusted source.
    payload = pickle.loads(user_defined_function_proto.payload)
    internal_timer_service = InternalTimerServiceImpl(keyed_state_backend)

    def state_key_selector(normal_data):
        # The state backend receives the key wrapped in a Row.
        return Row(normal_data[0])

    def user_key_selector(normal_data):
        return normal_data[0]

    def input_selector(normal_data):
        return normal_data[1]

    if func_type in (UserDefinedDataStreamFunction.KEYED_PROCESS,
                     UserDefinedDataStreamFunction.KEYED_CO_PROCESS):
        timer_service = TimerServiceImpl(internal_timer_service)
        on_timer_ctx = InternalKeyedProcessFunctionOnTimerContext(timer_service)
        ctx = InternalKeyedProcessFunctionContext(timer_service)
        process_function = payload
        output_factory = RowWithTimerOutputFactory(VoidNamespaceSerializer())

        def open_func():
            # open/close are optional on the user function.
            if hasattr(process_function, "open"):
                process_function.open(runtime_context)

        def close_func():
            if hasattr(process_function, "close"):
                process_function.close()

        if func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:

            def process_element(normal_data, timestamp: int):
                ctx.set_timestamp(timestamp)
                user_current_key = user_key_selector(normal_data)
                ctx.set_current_key(user_current_key)
                return process_function.process_element(input_selector(normal_data), ctx)

            def _on_timer(time_domain, internal_timer):
                # Shared by both time domains: populate the on-timer context from
                # the firing timer, then call the user's on_timer.
                timestamp = internal_timer.get_timestamp()
                state_current_key = internal_timer.get_key()
                user_current_key = user_key_selector(state_current_key)

                # The firing timestamp was previously never stored on the context,
                # although extract_stateful_function sets it on the same context
                # class before calling on_timer.
                on_timer_ctx.set_timestamp(timestamp)
                on_timer_ctx.set_current_key(user_current_key)
                on_timer_ctx.set_time_domain(time_domain)

                return process_function.on_timer(timestamp, on_timer_ctx)

            def on_event_time(internal_timer: InternalTimerImpl):
                return _on_timer(TimeDomain.EVENT_TIME, internal_timer)

            def on_processing_time(internal_timer: InternalTimerImpl):
                return _on_timer(TimeDomain.PROCESSING_TIME, internal_timer)

            input_handler = OneInputRowWithTimerHandler(
                internal_timer_service,
                keyed_state_backend,
                state_key_selector,
                process_element,
                on_event_time,
                on_processing_time,
                output_factory)
            process_element_func = input_handler.accept

        elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:
            input_handler = TwoInputRowWithTimerHandler(
                ctx,
                on_timer_ctx,
                timer_service,
                keyed_state_backend,
                process_function,
                output_factory)
            process_element_func = input_handler.accept

        else:
            raise Exception("Unsupported func_type: " + str(func_type))

    elif func_type == UserDefinedDataStreamFunction.WINDOW:
        window_operation_descriptor = payload
        window_assigner = window_operation_descriptor.assigner
        window_trigger = window_operation_descriptor.trigger
        allowed_lateness = window_operation_descriptor.allowed_lateness
        window_state_descriptor = window_operation_descriptor.window_state_descriptor
        internal_window_function = window_operation_descriptor.internal_window_function
        window_serializer = window_operation_descriptor.window_serializer
        # NOTE(review): extract_stateful_function stores `_get_coder().get_impl()`
        # in this attribute; confirm whether `_get_coder()` already returns a coder
        # impl on this code path or whether `.get_impl()` is missing here.
        keyed_state_backend._namespace_coder_impl = window_serializer._get_coder()
        window_operator = WindowOperator(
            window_assigner,
            keyed_state_backend,
            user_key_selector,
            window_state_descriptor,
            internal_window_function,
            window_trigger,
            allowed_lateness)
        output_factory = RowWithTimerOutputFactory(window_serializer)

        def open_func():
            window_operator.open(runtime_context, internal_timer_service)

        def close_func():
            window_operator.close()

        input_handler = OneInputRowWithTimerHandler(
            internal_timer_service,
            keyed_state_backend,
            state_key_selector,
            lambda n, t: window_operator.process_element(input_selector(n), t),
            window_operator.on_event_time,
            window_operator.on_processing_time,
            output_factory)
        process_element_func = input_handler.accept

    else:
        raise Exception("Unsupported func_type: " + str(func_type))

    return process_element_func, open_func, close_func