Exemplo n.º 1
0
def _create_user_defined_function_operation(factory, transform_proto,
                                            consumers, udfs_proto,
                                            beam_operation_cls,
                                            internal_operation_cls):
    """Build the Beam operation that runs the given user-defined functions.

    A ``WorkerDoFn`` spec is assembled from ``udfs_proto`` and the output
    coders of ``transform_proto``. A ``RemoteKeyedStateBackend`` is created
    and passed to ``beam_operation_cls`` as an extra trailing argument in
    two cases:

    * the serialized functions expose a ``key_type`` attribute. The key coder
      is built from ``key_type.row_schema``; a window coder is chosen from
      ``group_window`` when that field is set.
    * ``internal_operation_cls`` equals
      ``datastream_operations.StatefulOperation``. The key coder is built
      from ``key_type_info``.

    Otherwise the operation is constructed without a state backend.

    :param factory: supplies output coders, the state handler, the counter
        factory and the state sampler.
    :param transform_proto: transform whose ``outputs`` keys become the output
        tags and whose ``unique_name`` names the operation.
    :param consumers: downstream consumers forwarded to ``beam_operation_cls``.
    :param udfs_proto: serialized user-defined function(s); stored as the
        spec's ``serialized_fn``.
    :param beam_operation_cls: operation class to instantiate.
    :param internal_operation_cls: internal operation class forwarded to
        ``beam_operation_cls``.
    :return: the constructed ``beam_operation_cls`` instance.
    """
    tags = list(transform_proto.outputs.keys())
    coders_by_tag = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=tags,
        input=None,
        side_inputs=None,
        output_coders=[coders_by_tag[tag] for tag in tags])
    name_context = common.NameContext(transform_proto.unique_name)
    udf = spec.serialized_fn

    def _build(*extra):
        # Single construction site; ``extra`` is either empty or the keyed
        # state backend.
        return beam_operation_cls(name_context, spec, factory.counter_factory,
                                  factory.state_sampler, consumers,
                                  internal_operation_cls, *extra)

    if hasattr(udf, "key_type"):
        # Keyed operation: the key coder flattens the fields of the key row.
        schema_fields = udf.key_type.row_schema.fields
        key_coder = FlattenRowCoder([from_proto(fld.type) for fld in schema_fields])
        window_coder = None
        if udf.HasField('group_window'):
            window_coder = (TimeWindowCoder()
                            if udf.group_window.is_time_window
                            else CountWindowCoder())
        backend = RemoteKeyedStateBackend(
            factory.state_handler, key_coder, window_coder,
            udf.state_cache_size,
            udf.map_state_read_cache_size,
            udf.map_state_write_cache_size)
        return _build(backend)

    if internal_operation_cls == datastream_operations.StatefulOperation:
        # DataStream stateful operation: key coder comes from key_type_info
        # and no window namespace coder is used.
        key_coder = from_type_info_proto(udf.key_type_info)
        backend = RemoteKeyedStateBackend(
            factory.state_handler, key_coder, None,
            udf.state_cache_size,
            udf.map_state_read_cache_size,
            udf.map_state_write_cache_size)
        return _build(backend)

    return _build()
Exemplo n.º 2
0
def _create_user_defined_function_operation(factory, transform_proto,
                                            consumers, udfs_proto,
                                            beam_operation_cls,
                                            internal_operation_cls):
    """Build the Beam operation that runs the given user-defined functions.

    A ``WorkerDoFn`` spec is assembled from ``udfs_proto`` and the output
    coders of ``transform_proto``. If the serialized functions expose a
    ``key_type`` attribute, a ``RemoteKeyedStateBackend`` is created from it.
    That backend is then appended to the arguments of ``beam_operation_cls``.

    :param factory: supplies output coders, the state handler, the counter
        factory and the state sampler.
    :param transform_proto: transform whose ``outputs`` keys become the output
        tags and whose ``unique_name`` names the operation.
    :param consumers: downstream consumers forwarded to ``beam_operation_cls``.
    :param udfs_proto: serialized user-defined function(s); stored as the
        spec's ``serialized_fn``.
    :param beam_operation_cls: operation class to instantiate.
    :param internal_operation_cls: internal operation class forwarded to
        ``beam_operation_cls``.
    :return: the constructed ``beam_operation_cls`` instance.
    """
    tags = list(transform_proto.outputs.keys())
    coders_by_tag = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=tags,
        input=None,
        side_inputs=None,
        output_coders=[coders_by_tag[tag] for tag in tags])

    # Empty for stateless operations; holds the state backend for keyed ones.
    trailing_args = ()
    udf = spec.serialized_fn
    if hasattr(udf, "key_type"):
        key_coder = from_proto(udf.key_type)
        trailing_args = (RemoteKeyedStateBackend(
            factory.state_handler, key_coder,
            udf.state_cache_size,
            udf.map_state_read_cache_size,
            udf.map_state_write_cache_size),)

    return beam_operation_cls(transform_proto.unique_name, spec,
                              factory.counter_factory,
                              factory.state_sampler, consumers,
                              internal_operation_cls, *trailing_args)
Exemplo n.º 3
0
def _create_stateful_user_defined_function_operation(factory, transform_proto,
                                                     consumers, udfs_proto,
                                                     beam_operation_cls,
                                                     internal_operation_cls,
                                                     state_cache_size=1000,
                                                     map_state_read_cache_size=1000,
                                                     map_state_write_cache_size=1000):
    """Build a stateful Beam operation backed by a ``RemoteKeyedStateBackend``.

    :param factory: supplies output coders, the state handler, the counter
        factory and the state sampler.
    :param transform_proto: transform whose ``outputs`` keys become the output
        tags and whose ``unique_name`` names the operation.
    :param consumers: downstream consumers forwarded to ``beam_operation_cls``.
    :param udfs_proto: serialized user-defined function(s); must expose
        ``key_type_info`` with at least one ``field``.
    :param beam_operation_cls: operation class to instantiate.
    :param internal_operation_cls: internal operation class forwarded to
        ``beam_operation_cls``.
    :param state_cache_size: state cache size of the keyed state backend
        (default 1000, the previously hard-coded value).
    :param map_state_read_cache_size: map-state read cache size (default 1000).
    :param map_state_write_cache_size: map-state write cache size (default 1000).
    :return: the constructed ``beam_operation_cls`` instance.
    """
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        # Iterate output_tags (not the coder dict) so the coder list lines up
        # one-to-one, in order, with output_tags.
        output_coders=[output_coders[tag] for tag in output_tags])
    # The key coder is derived from the type of the first key_type_info field.
    key_type_info = spec.serialized_fn.key_type_info
    key_row_coder = from_type_info_proto(key_type_info.field[0].type)
    keyed_state_backend = RemoteKeyedStateBackend(factory.state_handler,
                                                  key_row_coder,
                                                  state_cache_size,
                                                  map_state_read_cache_size,
                                                  map_state_write_cache_size)

    return beam_operation_cls(transform_proto.unique_name, spec,
                              factory.counter_factory, factory.state_sampler,
                              consumers, internal_operation_cls,
                              keyed_state_backend)
Exemplo n.º 4
0
def extract_stateful_function(user_defined_function_proto,
                              runtime_context: RuntimeContext,
                              keyed_state_backend: RemoteKeyedStateBackend):
    """Build the runtime callbacks of a keyed, stateful DataStream function.

    The pickled payload of ``user_defined_function_proto`` is interpreted
    according to ``function_type``:

    * KEYED_PROCESS / KEYED_CO_PROCESS: a (co-)process function.
    * WINDOW: a window operation descriptor that drives a ``WindowOperator``.

    Incoming records carry the state key at index 0 and the user input at
    index 1. For KEYED_CO_PROCESS, index 0 is an ``is_left`` flag and index
    1 / 2 holds the left / right record, each with that key/input layout.

    :param user_defined_function_proto: proto with ``function_type`` and a
        pickled ``payload``.
    :param runtime_context: passed to the user function / window operator
        on open.
    :param keyed_state_backend: backend whose current key is switched before
        each element / timer is processed.
    :return: ``(open_func, close_func, process_element_func,
        process_timer_func, internal_timer_service)``.
    :raises Exception: if ``function_type`` is not supported.
    """
    func_type = user_defined_function_proto.function_type
    # NOTE(review): the payload is unpickled; presumably it comes from the
    # trusted job submitter — confirm it never carries external input.
    user_defined_func = pickle.loads(user_defined_function_proto.payload)
    internal_timer_service = InternalTimerServiceImpl(keyed_state_backend)

    def state_key_selector(normal_data):
        # The state backend is keyed by a Row wrapping the raw key.
        return Row(normal_data[0])

    def user_key_selector(normal_data):
        return normal_data[0]

    def input_selector(normal_data):
        return normal_data[1]

    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
    if func_type in (UserDefinedDataStreamFunction.KEYED_PROCESS,
                     UserDefinedDataStreamFunction.KEYED_CO_PROCESS):
        timer_service = TimerServiceImpl(internal_timer_service)
        ctx = InternalKeyedProcessFunctionContext(timer_service)
        on_timer_ctx = InternalKeyedProcessFunctionOnTimerContext(
            timer_service)
        process_function = user_defined_func
        # Keyed (co-)process timers carry no real namespace.
        internal_timer_service.set_namespace_serializer(
            VoidNamespaceSerializer())

        def open_func():
            if hasattr(process_function, "open"):
                process_function.open(runtime_context)

        def close_func():
            if hasattr(process_function, "close"):
                process_function.close()

        def on_event_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return _on_timer(TimeDomain.EVENT_TIME, timestamp, key)

        def on_processing_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return _on_timer(TimeDomain.PROCESSING_TIME, timestamp, key)

        def _on_timer(time_domain: TimeDomain, timestamp: int, key):
            # ``key`` is the state key Row; unwrap it for the user-facing context.
            user_current_key = user_key_selector(key)

            on_timer_ctx.set_timestamp(timestamp)
            on_timer_ctx.set_current_key(user_current_key)
            on_timer_ctx.set_time_domain(time_domain)

            return process_function.on_timer(timestamp, on_timer_ctx)

        if func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:

            def process_element(normal_data, timestamp: int):
                ctx.set_timestamp(timestamp)
                ctx.set_current_key(user_key_selector(normal_data))
                keyed_state_backend.set_current_key(
                    state_key_selector(normal_data))
                return process_function.process_element(
                    input_selector(normal_data), ctx)

        elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:

            def process_element(normal_data, timestamp: int):
                is_left = normal_data[0]
                if is_left:
                    user_input = normal_data[1]
                else:
                    user_input = normal_data[2]

                ctx.set_timestamp(timestamp)
                # The key must be set on the element context ``ctx`` (the one
                # handed to process_element1/2), not on the timer context.
                ctx.set_current_key(user_key_selector(user_input))
                keyed_state_backend.set_current_key(
                    state_key_selector(user_input))

                if is_left:
                    return process_function.process_element1(
                        input_selector(user_input), ctx)
                else:
                    return process_function.process_element2(
                        input_selector(user_input), ctx)

        else:
            raise Exception("Unsupported func_type: " + str(func_type))

    elif func_type == UserDefinedDataStreamFunction.WINDOW:
        window_operation_descriptor = user_defined_func
        window_assigner = window_operation_descriptor.assigner
        window_trigger = window_operation_descriptor.trigger
        allowed_lateness = window_operation_descriptor.allowed_lateness
        window_state_descriptor = window_operation_descriptor.window_state_descriptor
        internal_window_function = window_operation_descriptor.internal_window_function
        window_serializer = window_operation_descriptor.window_serializer
        window_coder = window_serializer._get_coder()
        # Windows act as the state namespace, so the backend must encode them.
        keyed_state_backend.namespace_coder = window_coder
        keyed_state_backend._namespace_coder_impl = window_coder.get_impl()
        window_operator = WindowOperator(window_assigner, keyed_state_backend,
                                         user_key_selector,
                                         window_state_descriptor,
                                         internal_window_function,
                                         window_trigger, allowed_lateness)
        internal_timer_service.set_namespace_serializer(window_serializer)

        def open_func():
            window_operator.open(runtime_context, internal_timer_service)

        def close_func():
            window_operator.close()

        def process_element(normal_data, timestamp: int):
            keyed_state_backend.set_current_key(
                state_key_selector(normal_data))
            return window_operator.process_element(input_selector(normal_data),
                                                   timestamp)

        def on_event_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return window_operator.on_event_time(timestamp, key, namespace)

        def on_processing_time(timestamp: int, key, namespace):
            keyed_state_backend.set_current_key(key)
            return window_operator.on_processing_time(timestamp, key,
                                                      namespace)

    else:
        raise Exception("Unsupported function_type: " + str(func_type))

    input_handler = RunnerInputHandler(internal_timer_service, process_element)
    process_element_func = input_handler.process_element

    timer_handler = TimerHandler(internal_timer_service, on_event_time,
                                 on_processing_time,
                                 keyed_state_backend._namespace_coder_impl)
    process_timer_func = timer_handler.process_timer

    return open_func, close_func, process_element_func, process_timer_func, internal_timer_service
Exemplo n.º 5
0
def extract_keyed_stateful_function(
        user_defined_function_proto,
        keyed_state_backend: RemoteKeyedStateBackend,
        runtime_context: RuntimeContext):
    """Build the runtime callbacks of a keyed, stateful DataStream function.

    The pickled payload of ``user_defined_function_proto`` is interpreted
    according to ``function_type``:

    * KEYED_PROCESS / KEYED_CO_PROCESS: a (co-)process function.
    * WINDOW: a window operation descriptor that drives a ``WindowOperator``.

    Records seen by the one-input paths carry the state key at index 0 and
    the user input at index 1.

    :param user_defined_function_proto: proto with ``function_type`` and a
        pickled ``payload``.
    :param keyed_state_backend: keyed state backend handed to the input
        handlers / window operator.
    :param runtime_context: passed to the user function / window operator
        on open.
    :return: ``(process_element_func, open_func, close_func)``.
    :raises Exception: if ``function_type`` is not supported.
    """
    func_type = user_defined_function_proto.function_type
    UserDefinedDataStreamFunction = flink_fn_execution_pb2.UserDefinedDataStreamFunction
    # NOTE(review): the payload is unpickled; presumably it comes from the
    # trusted job submitter — confirm it never carries external input.
    payload = pickle.loads(user_defined_function_proto.payload)
    internal_timer_service = InternalTimerServiceImpl(keyed_state_backend)

    def state_key_selector(normal_data):
        # The state backend is keyed by a Row wrapping the raw key.
        return Row(normal_data[0])

    def user_key_selector(normal_data):
        return normal_data[0]

    def input_selector(normal_data):
        return normal_data[1]

    if func_type in (UserDefinedDataStreamFunction.KEYED_PROCESS,
                     UserDefinedDataStreamFunction.KEYED_CO_PROCESS):
        timer_service = TimerServiceImpl(internal_timer_service)
        on_timer_ctx = InternalKeyedProcessFunctionOnTimerContext(
            timer_service)
        ctx = InternalKeyedProcessFunctionContext(timer_service)
        process_function = payload
        # Keyed (co-)process timers carry no real namespace.
        output_factory = RowWithTimerOutputFactory(VoidNamespaceSerializer())

        def open_func():
            if hasattr(process_function, "open"):
                process_function.open(runtime_context)

        def close_func():
            if hasattr(process_function, "close"):
                process_function.close()

        if func_type == UserDefinedDataStreamFunction.KEYED_PROCESS:

            def process_element(normal_data, timestamp: int):
                ctx.set_timestamp(timestamp)
                user_current_key = user_key_selector(normal_data)
                ctx.set_current_key(user_current_key)
                return process_function.process_element(
                    input_selector(normal_data), ctx)

            def _on_timer(time_domain: TimeDomain,
                          internal_timer: InternalTimerImpl):
                # Shared body of the event-time / processing-time callbacks.
                timestamp = internal_timer.get_timestamp()
                state_current_key = internal_timer.get_key()
                user_current_key = user_key_selector(state_current_key)

                # Record the firing timer's timestamp so that
                # OnTimerContext.timestamp() reports it.
                on_timer_ctx.set_timestamp(timestamp)
                on_timer_ctx.set_current_key(user_current_key)
                on_timer_ctx.set_time_domain(time_domain)

                return process_function.on_timer(timestamp, on_timer_ctx)

            def on_event_time(internal_timer: InternalTimerImpl):
                return _on_timer(TimeDomain.EVENT_TIME, internal_timer)

            def on_processing_time(internal_timer: InternalTimerImpl):
                return _on_timer(TimeDomain.PROCESSING_TIME, internal_timer)

            input_handler = OneInputRowWithTimerHandler(
                internal_timer_service, keyed_state_backend,
                state_key_selector, process_element, on_event_time,
                on_processing_time, output_factory)

            process_element_func = input_handler.accept
        elif func_type == UserDefinedDataStreamFunction.KEYED_CO_PROCESS:
            input_handler = TwoInputRowWithTimerHandler(
                ctx, on_timer_ctx, timer_service, keyed_state_backend,
                process_function, output_factory)

            process_element_func = input_handler.accept
        else:
            raise Exception("Unsupported func_type: " + str(func_type))
    elif func_type == UserDefinedDataStreamFunction.WINDOW:
        window_operation_descriptor = payload
        window_assigner = window_operation_descriptor.assigner
        window_trigger = window_operation_descriptor.trigger
        allowed_lateness = window_operation_descriptor.allowed_lateness
        window_state_descriptor = window_operation_descriptor.window_state_descriptor
        internal_window_function = window_operation_descriptor.internal_window_function
        window_serializer = window_operation_descriptor.window_serializer
        # Windows act as the state namespace, so the backend must encode them.
        keyed_state_backend._namespace_coder_impl = window_serializer._get_coder(
        )
        window_operator = WindowOperator(window_assigner, keyed_state_backend,
                                         user_key_selector,
                                         window_state_descriptor,
                                         internal_window_function,
                                         window_trigger, allowed_lateness)
        output_factory = RowWithTimerOutputFactory(window_serializer)

        def open_func():
            window_operator.open(runtime_context, internal_timer_service)

        def close_func():
            window_operator.close()

        input_handler = OneInputRowWithTimerHandler(
            internal_timer_service, keyed_state_backend, state_key_selector,
            lambda n, t: window_operator.process_element(input_selector(n), t),
            window_operator.on_event_time, window_operator.on_processing_time,
            output_factory)

        process_element_func = input_handler.accept
    else:
        raise Exception("Unsupported func_type: " + str(func_type))

    return process_element_func, open_func, close_func