Ejemplo n.º 1
0
def extract_stateful_function(user_defined_function_proto,
                              runtime_context: RuntimeContext,
                              keyed_state_backend: RemoteKeyedStateBackend):
    """Build the runtime callbacks for a keyed stateful DataStream function.

    The pickled payload of ``user_defined_function_proto`` is deserialized and,
    depending on ``function_type``, wired up either as a keyed (co-)process
    function or as a window operator.

    :param user_defined_function_proto: message exposing ``function_type`` and
        a pickled ``payload``.
    :param runtime_context: context handed to ``open`` of the user function or
        of the window operator.
    :param keyed_state_backend: state backend whose current key is switched
        before each element and each timer is processed.
    :return: the tuple ``(open_func, close_func, process_element_func,
        process_timer_func, internal_timer_service)``.
    :raises Exception: if ``function_type`` is none of ``KEYED_PROCESS``,
        ``KEYED_CO_PROCESS`` or ``WINDOW``.
    """
    func_type = user_defined_function_proto.function_type
    # NOTE(security): pickle.loads can execute arbitrary code; the payload must
    # only ever come from a trusted source.
    user_defined_func = pickle.loads(user_defined_function_proto.payload)
    internal_timer_service = InternalTimerServiceImpl(keyed_state_backend)

    # Keyed records are indexed as (key, value): index 0 is the key, index 1
    # the user-visible input.
    def state_key_selector(normal_data):
        # The state backend is given the key wrapped in a Row.
        return Row(normal_data[0])

    def user_key_selector(normal_data):
        return normal_data[0]

    def input_selector(normal_data):
        return normal_data[1]

    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
    if func_type in (UserDefinedDataStreamFunction.KEYED_PROCESS,
                     UserDefinedDataStreamFunction.KEYED_CO_PROCESS):
        timer_service = TimerServiceImpl(internal_timer_service)
        ctx = InternalKeyedProcessFunctionContext(timer_service)
        on_timer_ctx = InternalKeyedProcessFunctionOnTimerContext(
            timer_service)
        process_function = user_defined_func
        internal_timer_service.set_namespace_serializer(
            VoidNamespaceSerializer())

        def open_func():
            # ``open`` is optional on the user function.
            if hasattr(process_function, "open"):
                process_function.open(runtime_context)

        def close_func():
            # ``close`` is optional on the user function.
            if hasattr(process_function, "close"):
                process_function.close()

        def _on_timer(time_domain: TimeDomain, timestamp: int, key):
            # ``key`` is the state key (see state_key_selector); the user key
            # is its first field.
            user_current_key = user_key_selector(key)

            on_timer_ctx.set_timestamp(timestamp)
            on_timer_ctx.set_current_key(user_current_key)
            on_timer_ctx.set_time_domain(time_domain)

            return process_function.on_timer(timestamp, on_timer_ctx)

        def on_event_time(timestamp: int, key, namespace):
            # ``namespace`` is part of the callback signature but is not used
            # by process functions.
            keyed_state_backend.set_current_key(key)
            return _on_timer(TimeDomain.EVENT_TIME, timestamp, key)

        def on_processing_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return _on_timer(TimeDomain.PROCESSING_TIME, timestamp, key)

        if func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:

            def process_element(normal_data, timestamp: int):
                ctx.set_timestamp(timestamp)
                ctx.set_current_key(user_key_selector(normal_data))
                keyed_state_backend.set_current_key(
                    state_key_selector(normal_data))
                return process_function.process_element(
                    input_selector(normal_data), ctx)

        elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:

            def process_element(normal_data, timestamp: int):
                # Co-process records are indexed as (is_left, left, right);
                # the flag selects which side carries the (key, value) pair.
                is_left = normal_data[0]
                if is_left:
                    user_input = normal_data[1]
                else:
                    user_input = normal_data[2]

                ctx.set_timestamp(timestamp)
                current_user_key = user_key_selector(user_input)
                # ``ctx`` is the context passed to process_element1/2 below, so
                # it must carry the key of the current record (as the
                # KEYED_PROCESS branch does). The on-timer context keeps being
                # updated as before.
                ctx.set_current_key(current_user_key)
                on_timer_ctx.set_current_key(current_user_key)
                keyed_state_backend.set_current_key(
                    state_key_selector(user_input))

                if is_left:
                    return process_function.process_element1(
                        input_selector(user_input), ctx)
                else:
                    return process_function.process_element2(
                        input_selector(user_input), ctx)

        else:
            raise Exception("Unsupported func_type: " + str(func_type))

    elif func_type == UserDefinedDataStreamFunction.WINDOW:
        window_operation_descriptor = user_defined_func
        window_assigner = window_operation_descriptor.assigner
        window_trigger = window_operation_descriptor.trigger
        allowed_lateness = window_operation_descriptor.allowed_lateness
        window_state_descriptor = window_operation_descriptor.window_state_descriptor
        internal_window_function = window_operation_descriptor.internal_window_function
        window_serializer = window_operation_descriptor.window_serializer
        # The window coder becomes the namespace coder of the state backend;
        # its impl is also handed to the TimerHandler created below.
        window_coder = window_serializer._get_coder()
        keyed_state_backend.namespace_coder = window_coder
        keyed_state_backend._namespace_coder_impl = window_coder.get_impl()
        window_operator = WindowOperator(window_assigner, keyed_state_backend,
                                         user_key_selector,
                                         window_state_descriptor,
                                         internal_window_function,
                                         window_trigger, allowed_lateness)
        internal_timer_service.set_namespace_serializer(window_serializer)

        def open_func():
            window_operator.open(runtime_context, internal_timer_service)

        def close_func():
            window_operator.close()

        def process_element(normal_data, timestamp: int):
            keyed_state_backend.set_current_key(
                state_key_selector(normal_data))
            return window_operator.process_element(input_selector(normal_data),
                                                   timestamp)

        def on_event_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return window_operator.on_event_time(timestamp, key, namespace)

        def on_processing_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return window_operator.on_processing_time(timestamp, key,
                                                      namespace)

    else:
        raise Exception("Unsupported function_type: " + str(func_type))

    input_handler = RunnerInputHandler(internal_timer_service, process_element)
    process_element_func = input_handler.process_element

    timer_handler = TimerHandler(internal_timer_service, on_event_time,
                                 on_processing_time,
                                 keyed_state_backend._namespace_coder_impl)
    process_timer_func = timer_handler.process_timer

    return open_func, close_func, process_element_func, process_timer_func, internal_timer_service
Ejemplo n.º 2
0
def extract_keyed_stateful_function(
        user_defined_function_proto,
        keyed_state_backend: RemoteKeyedStateBackend,
        runtime_context: RuntimeContext):
    """Create the element handler and lifecycle hooks of a keyed function.

    The pickled payload of ``user_defined_function_proto`` is loaded and,
    according to ``function_type``, wrapped as a keyed process function, a
    keyed co-process function or a window operator.

    :param user_defined_function_proto: message exposing ``function_type`` and
        a pickled ``payload``.
    :param keyed_state_backend: state backend handed to the input handlers and
        to the window operator.
    :param runtime_context: context passed to ``open`` of the user function or
        of the window operator.
    :return: the tuple ``(process_element_func, open_func, close_func)`` where
        ``process_element_func`` is the ``accept`` method of the input handler.
    :raises Exception: if ``function_type`` is not supported.
    """
    func_type = user_defined_function_proto.function_type
    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
    # NOTE(security): pickle.loads can execute arbitrary code; the payload must
    # only ever come from a trusted source.
    payload = pickle.loads(user_defined_function_proto.payload)
    internal_timer_service = InternalTimerServiceImpl(keyed_state_backend)

    # Keyed records are indexed as (key, value).
    def state_key_selector(normal_data):
        # The state key is the record key wrapped in a Row.
        return Row(normal_data[0])

    def user_key_selector(normal_data):
        return normal_data[0]

    def input_selector(normal_data):
        return normal_data[1]

    if func_type in (UserDefinedDataStreamFunction.KEYED_PROCESS,
                     UserDefinedDataStreamFunction.KEYED_CO_PROCESS):
        timer_service = TimerServiceImpl(internal_timer_service)
        on_timer_ctx = InternalKeyedProcessFunctionOnTimerContext(
            timer_service)
        ctx = InternalKeyedProcessFunctionContext(timer_service)
        process_function = payload
        output_factory = RowWithTimerOutputFactory(VoidNamespaceSerializer())

        def open_func():
            # ``open`` is optional on the user function.
            if not hasattr(process_function, "open"):
                return
            process_function.open(runtime_context)

        def close_func():
            # ``close`` is optional on the user function.
            if not hasattr(process_function, "close"):
                return
            process_function.close()

        if func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:

            def process_element(record, record_timestamp: int):
                ctx.set_timestamp(record_timestamp)
                ctx.set_current_key(user_key_selector(record))
                user_value = input_selector(record)
                return process_function.process_element(user_value, ctx)

            def fire_timer(internal_timer, time_domain):
                # Shared by both time domains: publish the user key and the
                # domain on the timer context, then call the user's on_timer.
                fired_at = internal_timer.get_timestamp()
                timer_state_key = internal_timer.get_key()
                on_timer_ctx.set_current_key(
                    user_key_selector(timer_state_key))
                on_timer_ctx.set_time_domain(time_domain)
                return process_function.on_timer(fired_at, on_timer_ctx)

            def on_event_time(internal_timer: InternalTimerImpl):
                return fire_timer(internal_timer, TimeDomain.EVENT_TIME)

            def on_processing_time(internal_timer: InternalTimerImpl):
                return fire_timer(internal_timer, TimeDomain.PROCESSING_TIME)

            input_handler = OneInputRowWithTimerHandler(
                internal_timer_service,
                keyed_state_backend,
                state_key_selector,
                process_element,
                on_event_time,
                on_processing_time,
                output_factory)
        elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:
            # The two-input handler drives the user function directly.
            input_handler = TwoInputRowWithTimerHandler(
                ctx,
                on_timer_ctx,
                timer_service,
                keyed_state_backend,
                process_function,
                output_factory)
        else:
            raise Exception("Unsupported func_type: " + str(func_type))

        process_element_func = input_handler.accept
    elif func_type == UserDefinedDataStreamFunction.WINDOW:
        descriptor = payload
        window_assigner = descriptor.assigner
        window_trigger = descriptor.trigger
        allowed_lateness = descriptor.allowed_lateness
        window_state_descriptor = descriptor.window_state_descriptor
        internal_window_function = descriptor.internal_window_function
        window_serializer = descriptor.window_serializer
        # The serializer's coder is installed as the backend's namespace coder.
        namespace_coder = window_serializer._get_coder()
        keyed_state_backend._namespace_coder_impl = namespace_coder
        window_operator = WindowOperator(
            window_assigner,
            keyed_state_backend,
            user_key_selector,
            window_state_descriptor,
            internal_window_function,
            window_trigger,
            allowed_lateness)
        output_factory = RowWithTimerOutputFactory(window_serializer)

        def open_func():
            window_operator.open(runtime_context, internal_timer_service)

        def close_func():
            window_operator.close()

        def process_window_element(record, record_timestamp):
            # Only the value part of the record is given to the operator.
            return window_operator.process_element(
                input_selector(record), record_timestamp)

        input_handler = OneInputRowWithTimerHandler(
            internal_timer_service,
            keyed_state_backend,
            state_key_selector,
            process_window_element,
            window_operator.on_event_time,
            window_operator.on_processing_time,
            output_factory)

        process_element_func = input_handler.accept
    else:
        raise Exception("Unsupported func_type: " + str(func_type))

    return process_element_func, open_func, close_func