def __init__(self, state_handler, key_coder, namespace_coder, state_cache_size,
             map_state_read_cache_size, map_state_write_cache_size):
    self._state_handler = state_handler
    self._map_state_handler = CachingMapStateHandler(
        state_handler, map_state_read_cache_size)
    from pyflink.fn_execution.coders import FlattenRowCoder
    self._key_coder_impl = FlattenRowCoder(key_coder._field_coders).get_impl()
    self.namespace_coder = namespace_coder
    if namespace_coder:
        self._namespace_coder_impl = namespace_coder.get_impl()
    else:
        self._namespace_coder_impl = None
    self._state_cache_size = state_cache_size
    self._map_state_write_cache_size = map_state_write_cache_size
    self._all_states = {}
    self._internal_state_cache = LRUCache(self._state_cache_size, None)
    self._internal_state_cache.set_on_evict(
        lambda key, value: self.commit_internal_state(value))
    self._current_key = None
    self._encoded_current_key = None
    self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
        multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id="clear_iterators",
            side_input_id="clear_iterators",
            key=self._encoded_current_key))
def __init__(self, state_handler, key_coder, state_cache_size,
             map_state_read_cache_size, map_state_write_cache_size):
    self._state_handler = state_handler
    self._map_state_handler = CachingMapStateHandler(
        state_handler, map_state_read_cache_size)
    try:
        # prefer the fast (Cython) coder implementation when it is available
        from pyflink.fn_execution import coder_impl_fast
        is_fast = True if coder_impl_fast else False
    except ImportError:
        is_fast = False
    if not is_fast:
        self._key_coder_impl = key_coder.get_impl()
    else:
        from pyflink.fn_execution.coders import FlattenRowCoder
        self._key_coder_impl = FlattenRowCoder(
            key_coder._field_coders).get_impl()
    self._state_cache_size = state_cache_size
    self._map_state_write_cache_size = map_state_write_cache_size
    self._all_states = {}
    self._internal_state_cache = LRUCache(self._state_cache_size, None)
    self._internal_state_cache.set_on_evict(
        lambda key, value: self.commit_internal_state(value))
    self._current_key = None
    self._encoded_current_key = None
def test_flatten_row_coder(self):
    field_coder = BigIntCoder()
    field_count = 10
    coder = FlattenRowCoder([field_coder for _ in range(field_count)]).get_impl()
    v = [None if i % 2 == 0 else i for i in range(field_count)]
    generator_result = coder.decode(coder.encode(v))
    result = []
    for item in generator_result:
        result.append(item)
    self.assertEqual([v], result)
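For reference, the same round trip can be exercised outside the test harness. The following is a minimal sketch assuming `BigIntCoder` and `FlattenRowCoder` are importable from `pyflink.fn_execution.coders` as in the surrounding snippets; note that in this version of the coder stack `decode` returns a generator of rows, so it must be drained before reuse.

from pyflink.fn_execution.coders import BigIntCoder, FlattenRowCoder

coder = FlattenRowCoder([BigIntCoder() for _ in range(3)]).get_impl()
encoded = coder.encode([1, None, 3])
decoded = list(coder.decode(encoded))  # drains the generator: [[1, None, 3]]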
def __init__(self, state_handler, key_coder, state_cache_size):
    self._state_handler = state_handler
    try:
        # prefer the fast (Cython) coder implementation when it is available
        from pyflink.fn_execution import coder_impl_fast
        is_fast = True if coder_impl_fast else False
    except ImportError:
        is_fast = False
    if not is_fast:
        self._key_coder_impl = key_coder.get_impl()
    else:
        from pyflink.fn_execution.coders import FlattenRowCoder
        self._key_coder_impl = FlattenRowCoder(key_coder._field_coders).get_impl()
    self._state_cache_size = state_cache_size
    self._all_states = {}
    self._all_internal_states = LRUCache(self._state_cache_size, None)
    self._all_internal_states.set_on_evict(lambda k, v: v.commit())
    self._current_key = None
    self._encoded_current_key = None
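All of these constructors depend on an LRU cache with an eviction hook. Below is a minimal sketch of the contract they assume (bounded size, a default value on miss, `set_on_evict`, `evict_all`, membership tests, and iteration over cached values); the actual implementation shipped with PyFlink may differ in detail.

from collections import OrderedDict


class LRUCache(object):
    """Bounded mapping which evicts the least recently used entry first."""

    def __init__(self, max_entries, default_entry):
        self._max_entries = max_entries
        self._default_entry = default_entry
        self._cache = OrderedDict()
        self._on_evict = None

    def get(self, key):
        if key not in self._cache:
            return self._default_entry
        self._cache.move_to_end(key)  # mark as most recently used
        return self._cache[key]

    def put(self, key, value):
        self._cache[key] = value
        self._cache.move_to_end(key)
        while len(self._cache) > self._max_entries:
            evicted_key, evicted_value = self._cache.popitem(last=False)
            if self._on_evict is not None:
                self._on_evict(evicted_key, evicted_value)

    def evict_all(self):
        # drain the cache, invoking the eviction hook for every entry
        if self._on_evict is not None:
            for key, value in self._cache.items():
                self._on_evict(key, value)
        self._cache.clear()

    def set_on_evict(self, func):
        self._on_evict = func

    def __contains__(self, key):
        return key in self._cache

    def __iter__(self):
        return iter(self._cache.values())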
def _create_user_defined_function_operation(factory, transform_proto, consumers, udfs_proto,
                                            beam_operation_cls, internal_operation_cls):
    output_tags = list(transform_proto.outputs.keys())
    output_coders = factory.get_output_coders(transform_proto)
    spec = operation_specs.WorkerDoFn(
        serialized_fn=udfs_proto,
        output_tags=output_tags,
        input=None,
        side_inputs=None,
        output_coders=[output_coders[tag] for tag in output_tags])
    name = common.NameContext(transform_proto.unique_name)
    serialized_fn = spec.serialized_fn
    if hasattr(serialized_fn, "key_type"):
        # keyed operation, need to create the KeyedStateBackend.
        row_schema = serialized_fn.key_type.row_schema
        key_row_coder = FlattenRowCoder(
            [from_proto(f.type) for f in row_schema.fields])
        if serialized_fn.HasField('group_window'):
            if serialized_fn.group_window.is_time_window:
                window_coder = TimeWindowCoder()
            else:
                window_coder = CountWindowCoder()
        else:
            window_coder = None
        keyed_state_backend = RemoteKeyedStateBackend(
            factory.state_handler,
            key_row_coder,
            window_coder,
            serialized_fn.state_cache_size,
            serialized_fn.map_state_read_cache_size,
            serialized_fn.map_state_write_cache_size)
        return beam_operation_cls(
            name,
            spec,
            factory.counter_factory,
            factory.state_sampler,
            consumers,
            internal_operation_cls,
            keyed_state_backend)
    elif internal_operation_cls == datastream_operations.StatefulOperation:
        key_row_coder = from_type_info_proto(serialized_fn.key_type_info)
        keyed_state_backend = RemoteKeyedStateBackend(
            factory.state_handler,
            key_row_coder,
            None,
            serialized_fn.state_cache_size,
            serialized_fn.map_state_read_cache_size,
            serialized_fn.map_state_write_cache_size)
        return beam_operation_cls(
            name,
            spec,
            factory.counter_factory,
            factory.state_sampler,
            consumers,
            internal_operation_cls,
            keyed_state_backend)
    else:
        return beam_operation_cls(
            name,
            spec,
            factory.counter_factory,
            factory.state_sampler,
            consumers,
            internal_operation_cls)
class RemoteKeyedStateBackend(object):
    """
    A keyed state backend provides methods for managing keyed state.
    """

    def __init__(self, state_handler, key_coder, namespace_coder, state_cache_size,
                 map_state_read_cache_size, map_state_write_cache_size):
        self._state_handler = state_handler
        self._map_state_handler = CachingMapStateHandler(
            state_handler, map_state_read_cache_size)
        from pyflink.fn_execution.coders import FlattenRowCoder
        self._key_coder_impl = FlattenRowCoder(key_coder._field_coders).get_impl()
        self.namespace_coder = namespace_coder
        if namespace_coder:
            self._namespace_coder_impl = namespace_coder.get_impl()
        else:
            self._namespace_coder_impl = None
        self._state_cache_size = state_cache_size
        self._map_state_write_cache_size = map_state_write_cache_size
        self._all_states = {}
        self._internal_state_cache = LRUCache(self._state_cache_size, None)
        self._internal_state_cache.set_on_evict(
            lambda key, value: self.commit_internal_state(value))
        self._current_key = None
        self._encoded_current_key = None
        self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="clear_iterators",
                side_input_id="clear_iterators",
                key=self._encoded_current_key))

    def get_list_state(self, name, element_coder):
        return self._wrap_internal_bag_state(
            name, element_coder, SynchronousListRuntimeState, SynchronousListRuntimeState)

    def get_value_state(self, name, value_coder):
        return self._wrap_internal_bag_state(
            name, value_coder, SynchronousValueRuntimeState, SynchronousValueRuntimeState)

    def get_map_state(self, name, map_key_coder, map_value_coder):
        if name in self._all_states:
            self.validate_map_state(name, map_key_coder, map_value_coder)
            return self._all_states[name]
        map_state = SynchronousMapRuntimeState(name, map_key_coder, map_value_coder, self)
        self._all_states[name] = map_state
        return map_state

    def get_reducing_state(self, name, coder, reduce_function):
        return self._wrap_internal_bag_state(
            name, coder, SynchronousReducingRuntimeState,
            partial(SynchronousReducingRuntimeState, reduce_function=reduce_function))

    def get_aggregating_state(self, name, coder, agg_function):
        return self._wrap_internal_bag_state(
            name, coder, SynchronousAggregatingRuntimeState,
            partial(SynchronousAggregatingRuntimeState, agg_function=agg_function))

    def validate_state(self, name, coder, expected_type):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, expected_type):
                raise Exception("The state name '%s' is already in use and not a %s."
                                % (name, expected_type))
            if state._value_coder != coder:
                raise Exception("State name corrupted: %s" % name)

    def validate_map_state(self, name, map_key_coder, map_value_coder):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, SynchronousMapRuntimeState):
                raise Exception("The state name '%s' is already in use and not a map state."
                                % name)
            if state._map_key_coder != map_key_coder or \
                    state._map_value_coder != map_value_coder:
                raise Exception("State name corrupted: %s" % name)

    def _wrap_internal_bag_state(self, name, element_coder, wrapper_type, wrap_method):
        if name in self._all_states:
            self.validate_state(name, element_coder, wrapper_type)
            return self._all_states[name]
        wrapped_state = wrap_method(name, element_coder, self)
        self._all_states[name] = wrapped_state
        return wrapped_state

    def _get_internal_bag_state(self, name, namespace, element_coder):
        encoded_namespace = self._encode_namespace(namespace)
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key, encoded_namespace))
        if cached_state is not None:
            return cached_state
        # The created internal state would not be put into the internal state cache
        # at once. The internal state cache is only updated when the current key changes.
        # The reason is that the state cache size may be smaller than the count of activated
        # state (i.e. the state with current key).
        state_spec = userstate.BagStateSpec(name, element_coder)
        internal_state = self._create_bag_state(state_spec, encoded_namespace)
        return internal_state

    def _get_internal_map_state(self, name, namespace, map_key_coder, map_value_coder):
        encoded_namespace = self._encode_namespace(namespace)
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key, encoded_namespace))
        if cached_state is not None:
            return cached_state
        internal_map_state = self._create_internal_map_state(
            name, encoded_namespace, map_key_coder, map_value_coder)
        return internal_map_state

    def _create_bag_state(self, state_spec: userstate.StateSpec, encoded_namespace) \
            -> userstate.AccumulatingRuntimeState:
        if isinstance(state_spec, userstate.BagStateSpec):
            bag_state = SynchronousBagRuntimeState(
                self._state_handler,
                state_key=beam_fn_api_pb2.StateKey(
                    bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                        transform_id="",
                        window=encoded_namespace,
                        user_state_id=state_spec.name,
                        key=self._encoded_current_key)),
                value_coder=state_spec.coder)
            return bag_state
        else:
            raise NotImplementedError(state_spec)

    def _create_internal_map_state(self, name, encoded_namespace, map_key_coder, map_value_coder):
        # Currently the `beam_fn_api.proto` does not support MapState, so we use the
        # `MultimapSideInput` message to mark the state as a MapState for now.
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="",
                window=encoded_namespace,
                side_input_id=name,
                key=self._encoded_current_key))
        return InternalSynchronousMapRuntimeState(
            self._map_state_handler,
            state_key,
            map_key_coder,
            map_value_coder,
            self._map_state_write_cache_size)

    def _encode_namespace(self, namespace):
        if namespace is not None:
            encoded_namespace = self._namespace_coder_impl.encode_nested(namespace)
        else:
            encoded_namespace = b''
        return encoded_namespace

    def cache_internal_state(self, encoded_key, internal_kv_state: SynchronousKvRuntimeState):
        encoded_old_namespace = self._encode_namespace(internal_kv_state.namespace)
        self._internal_state_cache.put(
            (internal_kv_state.name, encoded_key, encoded_old_namespace),
            internal_kv_state.get_internal_state())

    def set_current_key(self, key):
        if key == self._current_key:
            return
        encoded_old_key = self._encoded_current_key
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode_nested(self._current_key)
        for state_name, state_obj in self._all_states.items():
            if self._state_cache_size > 0:
                # cache old internal state
                self.cache_internal_state(encoded_old_key, state_obj)
            state_obj.namespace = None
            state_obj._internal_state = None

    def get_current_key(self):
        return self._current_key

    def commit(self):
        for internal_state in self._internal_state_cache:
            self.commit_internal_state(internal_state)
        for name, state in self._all_states.items():
            if (name, self._encoded_current_key) not in self._internal_state_cache:
                self.commit_internal_state(state._internal_state)

    def clear_cached_iterators(self):
        if self._map_state_handler.get_cached_iterators_num() > 0:
            self._clear_iterator_mark.multimap_side_input.key = self._encoded_current_key
            self._map_state_handler.clear(self._clear_iterator_mark)

    @staticmethod
    def commit_internal_state(internal_state):
        if internal_state is not None:
            internal_state.commit()
        # reset the status of the internal state to reuse the object across bundles
        if isinstance(internal_state, SynchronousBagRuntimeState):
            internal_state._cleared = False
            internal_state._added_elements = []
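A hedged usage sketch of the namespaced backend above. Every name on the right-hand side (`state_handler`, `key_row_coder`, `window_coder`, `value_coder`, `user_id`) is a placeholder for objects normally wired in by the operation factory, so this illustrates the call pattern rather than a runnable setup.

backend = RemoteKeyedStateBackend(
    state_handler,    # runner-provided Beam state API handler
    key_row_coder,    # FlattenRowCoder over the key fields
    window_coder,     # namespace coder, or None outside group windows
    1000,             # state_cache_size
    1000,             # map_state_read_cache_size
    1000)             # map_state_write_cache_size

backend.set_current_key((user_id,))          # re-keys all wrapper states
counter = backend.get_value_state("count", value_coder)
counter.update((counter.value() or 0) + 1)   # ValueState-style access
backend.commit()                             # flush cached internal states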
class RemoteKeyedStateBackend(object):
    """
    A keyed state backend provides methods for managing keyed state.
    """

    def __init__(self, state_handler, key_coder, state_cache_size,
                 map_state_read_cache_size, map_state_write_cache_size):
        self._state_handler = state_handler
        self._map_state_handler = CachingMapStateHandler(
            state_handler, map_state_read_cache_size)
        from pyflink.fn_execution.coders import FlattenRowCoder
        self._key_coder_impl = FlattenRowCoder(
            key_coder._field_coders).get_impl()
        self._state_cache_size = state_cache_size
        self._map_state_write_cache_size = map_state_write_cache_size
        self._all_states = {}
        self._internal_state_cache = LRUCache(self._state_cache_size, None)
        self._internal_state_cache.set_on_evict(
            lambda key, value: self.commit_internal_state(value))
        self._current_key = None
        self._encoded_current_key = None
        self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="clear_iterators",
                side_input_id="clear_iterators",
                key=self._encoded_current_key))

    def get_list_state(self, name, element_coder):
        if name in self._all_states:
            self.validate_list_state(name, element_coder)
            return self._all_states[name]
        internal_bag_state = self._get_internal_bag_state(name, element_coder)
        list_state = SynchronousListRuntimeState(internal_bag_state)
        self._all_states[name] = list_state
        return list_state

    def get_value_state(self, name, value_coder):
        if name in self._all_states:
            self.validate_value_state(name, value_coder)
            return self._all_states[name]
        internal_bag_state = self._get_internal_bag_state(name, value_coder)
        value_state = SynchronousValueRuntimeState(internal_bag_state)
        self._all_states[name] = value_state
        return value_state

    def get_map_state(self, name, map_key_coder, map_value_coder):
        if name in self._all_states:
            self.validate_map_state(name, map_key_coder, map_value_coder)
            return self._all_states[name]
        internal_map_state = self._get_internal_map_state(
            name, map_key_coder, map_value_coder)
        map_state = SynchronousMapRuntimeState(internal_map_state)
        self._all_states[name] = map_state
        return map_state

    def validate_value_state(self, name, coder):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, SynchronousValueRuntimeState):
                raise Exception(
                    "The state name '%s' is already in use and not a value state." % name)
            if state._internal_state._value_coder != coder:
                raise Exception("State name corrupted: %s" % name)

    def validate_list_state(self, name, coder):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, SynchronousListRuntimeState):
                raise Exception(
                    "The state name '%s' is already in use and not a list state." % name)
            if state._internal_state._value_coder != coder:
                raise Exception("State name corrupted: %s" % name)

    def validate_map_state(self, name, map_key_coder, map_value_coder):
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, SynchronousMapRuntimeState):
                raise Exception(
                    "The state name '%s' is already in use and not a map state." % name)
            if state._internal_state._map_key_coder != map_key_coder or \
                    state._internal_state._map_value_coder != map_value_coder:
                raise Exception("State name corrupted: %s" % name)

    def _get_internal_bag_state(self, name, element_coder):
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key))
        if cached_state is not None:
            return cached_state
        state_spec = userstate.BagStateSpec(name, element_coder)
        internal_state = self._create_bag_state(state_spec)
        return internal_state

    def _get_internal_map_state(self, name, map_key_coder, map_value_coder):
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key))
        if cached_state is not None:
            return cached_state
        internal_map_state = self._create_internal_map_state(
            name, map_key_coder, map_value_coder)
        return internal_map_state

    def _create_bag_state(self, state_spec: userstate.StateSpec) \
            -> userstate.AccumulatingRuntimeState:
        if isinstance(state_spec, userstate.BagStateSpec):
            bag_state = SynchronousBagRuntimeState(
                self._state_handler,
                state_key=beam_fn_api_pb2.StateKey(
                    bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                        transform_id="",
                        user_state_id=state_spec.name,
                        key=self._encoded_current_key)),
                value_coder=state_spec.coder)
            return bag_state
        else:
            raise NotImplementedError(state_spec)

    def _create_internal_map_state(self, name, map_key_coder, map_value_coder):
        # Currently the `beam_fn_api.proto` does not support MapState, so we use the
        # `MultimapSideInput` message to mark the state as a MapState for now.
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="",
                side_input_id=name,
                key=self._encoded_current_key))
        return InternalSynchronousMapRuntimeState(
            self._map_state_handler,
            state_key,
            map_key_coder,
            map_value_coder,
            self._map_state_write_cache_size)

    def set_current_key(self, key):
        if key == self._current_key:
            return
        encoded_old_key = self._encoded_current_key
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode_nested(
            self._current_key)
        for state_name, state_obj in self._all_states.items():
            if self._state_cache_size > 0:
                # cache old internal state
                self._internal_state_cache.put(
                    (state_name, encoded_old_key), state_obj._internal_state)
            if isinstance(
                    state_obj,
                    (SynchronousValueRuntimeState, SynchronousListRuntimeState)):
                state_obj._internal_state = self._get_internal_bag_state(
                    state_name, state_obj._internal_state._value_coder)
            elif isinstance(state_obj, SynchronousMapRuntimeState):
                state_obj._internal_state = self._get_internal_map_state(
                    state_name,
                    state_obj._internal_state._map_key_coder,
                    state_obj._internal_state._map_value_coder)
            else:
                raise Exception("Unknown internal state '%s': %s" % (state_name, state_obj))

    def get_current_key(self):
        return self._current_key

    def commit(self):
        for internal_state in self._internal_state_cache:
            self.commit_internal_state(internal_state)
        for name, state in self._all_states.items():
            if (name, self._encoded_current_key) not in self._internal_state_cache:
                self.commit_internal_state(state._internal_state)

    def clear_cached_iterators(self):
        if self._map_state_handler.get_cached_iterators_num() > 0:
            self._clear_iterator_mark.multimap_side_input.key = self._encoded_current_key
            self._map_state_handler.clear(self._clear_iterator_mark)

    @staticmethod
    def commit_internal_state(internal_state):
        internal_state.commit()
        # reset the status of the internal state to reuse the object across bundles
        if isinstance(internal_state, SynchronousBagRuntimeState):
            internal_state._cleared = False
            internal_state._added_elements = []
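The interaction between `set_current_key`, the LRU cache, and `commit_internal_state` can be seen in isolation using the LRUCache sketch given earlier and a hypothetical stub standing in for an internal state: caching the state for the old key can push out the least recently used entry, and the eviction hook commits it.

class StubInternalState(object):
    """Hypothetical stand-in for a SynchronousBagRuntimeState."""

    def __init__(self, name):
        self.name = name
        self.committed = False

    def commit(self):
        self.committed = True


cache = LRUCache(1, None)  # room for a single internal state
cache.set_on_evict(lambda key, state: state.commit())

old_state = StubInternalState("count@key-1")
cache.put(("count", b"key-1"), old_state)
cache.put(("count", b"key-2"), StubInternalState("count@key-2"))  # evicts key-1
assert old_state.committed  # the evicted internal state was committed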
def test_flatten_row_coder(self):
    field_coder = BigIntCoder()
    field_count = 10
    coder = FlattenRowCoder([field_coder for _ in range(field_count)])
    self.check_coder(
        coder, [None if i % 2 == 0 else i for i in range(field_count)])
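`check_coder` is a helper on the test base class; a plausible minimal form is sketched below, assuming the coder implementation in this newer stack round-trips a value directly through `encode`/`decode`. The real helper may perform additional checks.

def check_coder(self, coder, *values):
    coder_impl = coder.get_impl()
    for value in values:
        # encode then decode and require an exact round trip
        self.assertEqual(value, coder_impl.decode(coder_impl.encode(value)))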
class RemoteKeyedStateBackend(object):
    """
    A keyed state backend provides methods for managing keyed state.
    """

    def __init__(self, state_handler, key_coder, state_cache_size):
        self._state_handler = state_handler
        try:
            # prefer the fast (Cython) coder implementation when it is available
            from pyflink.fn_execution import coder_impl_fast
            is_fast = True if coder_impl_fast else False
        except ImportError:
            is_fast = False
        if not is_fast:
            self._key_coder_impl = key_coder.get_impl()
        else:
            from pyflink.fn_execution.coders import FlattenRowCoder
            self._key_coder_impl = FlattenRowCoder(key_coder._field_coders).get_impl()
        self._state_cache_size = state_cache_size
        self._all_states = {}
        self._all_internal_states = LRUCache(self._state_cache_size, None)
        self._all_internal_states.set_on_evict(lambda k, v: v.commit())
        self._current_key = None
        self._encoded_current_key = None

    def get_value_state(self, name, value_coder):
        if name in self._all_states:
            self.validate_state(name, value_coder)
            return self._all_states[name]
        internal_bag_state = self._get_internal_bag_state(name, value_coder)
        value_state = SynchronousValueRuntimeState(internal_bag_state)
        self._all_states[name] = value_state
        return value_state

    def validate_state(self, name, coder):
        if name in self._all_states:
            state = self._all_states[name]
            if state._internal_state._value_coder != coder:
                raise ValueError("State name corrupted: %s" % name)

    def _get_internal_bag_state(self, name, element_coder):
        cached_state = self._all_internal_states.get((name, self._current_key))
        if cached_state is not None:
            return cached_state
        state_spec = userstate.BagStateSpec(name, element_coder)
        internal_state = self._create_state(state_spec)
        self._all_internal_states.put((name, self._current_key), internal_state)
        return internal_state

    def _create_state(self, state_spec: userstate.StateSpec) -> userstate.AccumulatingRuntimeState:
        if isinstance(state_spec, userstate.BagStateSpec):
            bag_state = SynchronousBagRuntimeState(
                self._state_handler,
                state_key=beam_fn_api_pb2.StateKey(
                    bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                        transform_id="",
                        user_state_id=state_spec.name,
                        key=self._encoded_current_key)),
                value_coder=state_spec.coder)
            return bag_state
        else:
            raise NotImplementedError(state_spec)

    def set_current_key(self, key):
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode_nested(self._current_key)
        for state_name, state_obj in self._all_states.items():
            state_obj._internal_state = \
                self._get_internal_bag_state(state_name, state_obj._internal_state._value_coder)

    def get_current_key(self):
        return self._current_key

    def commit(self):
        self._all_internal_states.evict_all()
        self._all_states = {}

    def reset(self):
        self._all_internal_states.evict_all()
        self._all_states = {}
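A hedged usage sketch of this earliest variant; `state_handler`, `key_coder`, `value_coder`, and `user_id` are placeholders for runner-provided objects. Note that here `commit()` evicts (and thereby commits) every cached internal state and drops the wrapper states, so state handles must be re-acquired in the next bundle.

backend = RemoteKeyedStateBackend(state_handler, key_coder, 100)
backend.set_current_key((user_id,))   # must precede any state access
count_state = backend.get_value_state("count", value_coder)
count_state.update(1)
backend.commit()                      # evict-all commits, then clears states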