示例#1
0
def _create_user_defined_function_operation(factory, transform_proto,
                                            consumers, udfs_proto,
                                            beam_operation_cls,
                                            internal_operation_cls):
    """Instantiate ``beam_operation_cls`` for a user-defined function.

    A ``WorkerDoFn`` spec is assembled from ``udfs_proto`` and the outputs of
    ``transform_proto``. Depending on the serialized function and the
    requested internal operation class, a ``RemoteKeyedStateBackend`` may be
    created and handed to the operation as an extra trailing argument.

    :param factory: supplies the output coders, state handler, counter
        factory and state sampler used below.
    :param transform_proto: transform description; its ``outputs`` keys and
        ``unique_name`` are read.
    :param consumers: forwarded unchanged to ``beam_operation_cls``.
    :param udfs_proto: serialized user-defined function(s), stored in the
        spec as ``serialized_fn``.
    :param beam_operation_cls: callable producing the returned operation.
    :param internal_operation_cls: forwarded to ``beam_operation_cls``; also
        compared against ``datastream_operations.StatefulOperation``.
    :return: whatever ``beam_operation_cls`` returns.
    """
    tags = list(transform_proto.outputs.keys())
    coders_by_tag = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=tags,
        input=None,
        side_inputs=None,
        output_coders=[coders_by_tag[tag] for tag in tags])
    name = common.NameContext(transform_proto.unique_name)
    fn_proto = spec.serialized_fn

    if hasattr(fn_proto, "key_type"):
        # Keyed operation: the key coder is derived from the key row schema.
        schema_fields = fn_proto.key_type.row_schema.fields
        key_coder = FlattenRowCoder(
            [from_proto(field.type) for field in schema_fields])
        # A window coder is only needed when a group window is configured.
        if not fn_proto.HasField('group_window'):
            window_coder = None
        elif fn_proto.group_window.is_time_window:
            window_coder = TimeWindowCoder()
        else:
            window_coder = CountWindowCoder()
    elif internal_operation_cls == datastream_operations.StatefulOperation:
        # Stateful DataStream operation: key coder comes from the type info.
        key_coder = from_type_info_proto(fn_proto.key_type_info)
        window_coder = None
    else:
        # No keyed state required; the operation gets no state backend.
        return beam_operation_cls(
            name, spec, factory.counter_factory, factory.state_sampler,
            consumers, internal_operation_cls)

    state_backend = RemoteKeyedStateBackend(
        factory.state_handler,
        key_coder,
        window_coder,
        fn_proto.state_cache_size,
        fn_proto.map_state_read_cache_size,
        fn_proto.map_state_write_cache_size)
    return beam_operation_cls(
        name, spec, factory.counter_factory, factory.state_sampler,
        consumers, internal_operation_cls, state_backend)
示例#2
0
 def test_window_coder(self):
     """Run ``check_coder`` on one sample window per window coder."""
     # Time-based window, constructed as TimeWindow(100, 1000).
     self.check_coder(TimeWindowCoder(), TimeWindow(100, 1000))
     # Count-based window, constructed as CountWindow(100).
     self.check_coder(CountWindowCoder(), CountWindow(100))