Example #1
0
class RemoteKeyedStateBackend(object):
    """
    Backend that hands out value, list and map states scoped to a current key.

    One wrapper object per state name is kept in ``_all_states``. Whenever the
    current key changes, every wrapper is re-pointed at the internal state of
    the new key, while the internal state of the previous key is put into an
    LRU cache whose on-evict callback commits it.
    """

    def __init__(self, state_handler, key_coder, state_cache_size,
                 map_state_read_cache_size, map_state_write_cache_size):
        """
        :param state_handler: handler given to bag states and wrapped in a
            ``CachingMapStateHandler`` for map states.
        :param key_coder: coder whose ``_field_coders`` build the
            ``FlattenRowCoder`` used to encode the current key.
        :param state_cache_size: capacity given to the internal state cache;
            old-key states are only cached when it is greater than zero.
        :param map_state_read_cache_size: forwarded to ``CachingMapStateHandler``.
        :param map_state_write_cache_size: forwarded to every internal map state.
        """
        self._state_handler = state_handler
        self._map_state_handler = CachingMapStateHandler(
            state_handler, map_state_read_cache_size)

        from pyflink.fn_execution.coders import FlattenRowCoder
        flattened_key_coder = FlattenRowCoder(key_coder._field_coders)
        self._key_coder_impl = flattened_key_coder.get_impl()

        self._state_cache_size = state_cache_size
        self._map_state_write_cache_size = map_state_write_cache_size
        self._all_states = {}

        def _commit_evicted(key, value):
            # Only the evicted internal state matters; the cache key is unused.
            return self.commit_internal_state(value)

        self._internal_state_cache = LRUCache(self._state_cache_size, None)
        self._internal_state_cache.set_on_evict(_commit_evicted)

        self._current_key = None
        self._encoded_current_key = None

        # Reusable request object; its key is refreshed in
        # clear_cached_iterators() before every use.
        iterator_mark = beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id="clear_iterators",
            side_input_id="clear_iterators",
            key=self._encoded_current_key)
        self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
            multimap_side_input=iterator_mark)

    def get_list_state(self, name, element_coder):
        """Return the list state registered as ``name``, creating it on first use."""
        if name not in self._all_states:
            bag = self._get_internal_bag_state(name, element_coder)
            created = SynchronousListRuntimeState(bag)
            self._all_states[name] = created
            return created
        self.validate_list_state(name, element_coder)
        return self._all_states[name]

    def get_value_state(self, name, value_coder):
        """Return the value state registered as ``name``, creating it on first use."""
        if name not in self._all_states:
            bag = self._get_internal_bag_state(name, value_coder)
            created = SynchronousValueRuntimeState(bag)
            self._all_states[name] = created
            return created
        self.validate_value_state(name, value_coder)
        return self._all_states[name]

    def get_map_state(self, name, map_key_coder, map_value_coder):
        """Return the map state registered as ``name``, creating it on first use."""
        if name not in self._all_states:
            internal_map = self._get_internal_map_state(
                name, map_key_coder, map_value_coder)
            created = SynchronousMapRuntimeState(internal_map)
            self._all_states[name] = created
            return created
        self.validate_map_state(name, map_key_coder, map_value_coder)
        return self._all_states[name]

    def validate_value_state(self, name, coder):
        """
        Check that an already registered ``name`` is a value state using ``coder``.

        Does nothing when ``name`` is not registered.

        :raises Exception: if the name belongs to another state type or the
            coder differs from the registered one.
        """
        if name not in self._all_states:
            return
        registered = self._all_states[name]
        if not isinstance(registered, SynchronousValueRuntimeState):
            raise Exception(
                "The state name '%s' is already in use and not a value state."
                % name)
        if registered._internal_state._value_coder != coder:
            raise Exception("State name corrupted: %s" % name)

    def validate_list_state(self, name, coder):
        """
        Check that an already registered ``name`` is a list state using ``coder``.

        Does nothing when ``name`` is not registered.

        :raises Exception: if the name belongs to another state type or the
            coder differs from the registered one.
        """
        if name not in self._all_states:
            return
        registered = self._all_states[name]
        if not isinstance(registered, SynchronousListRuntimeState):
            raise Exception(
                "The state name '%s' is already in use and not a list state."
                % name)
        if registered._internal_state._value_coder != coder:
            raise Exception("State name corrupted: %s" % name)

    def validate_map_state(self, name, map_key_coder, map_value_coder):
        """
        Check that an already registered ``name`` is a map state with the given coders.

        Does nothing when ``name`` is not registered.

        :raises Exception: if the name belongs to another state type or either
            coder differs from the registered one.
        """
        if name not in self._all_states:
            return
        registered = self._all_states[name]
        if not isinstance(registered, SynchronousMapRuntimeState):
            raise Exception(
                "The state name '%s' is already in use and not a map state."
                % name)
        internal_map = registered._internal_state
        if internal_map._map_key_coder != map_key_coder or \
                internal_map._map_value_coder != map_value_coder:
            raise Exception("State name corrupted: %s" % name)

    def _get_internal_bag_state(self, name, element_coder):
        """Return the cached bag state for ``name`` and the current key, or a new one."""
        cache_key = (name, self._encoded_current_key)
        hit = self._internal_state_cache.get(cache_key)
        if hit is not None:
            return hit
        # A newly created state is not put into the cache here; the only
        # put() in this class happens in set_current_key().
        return self._create_bag_state(userstate.BagStateSpec(name, element_coder))

    def _get_internal_map_state(self, name, map_key_coder, map_value_coder):
        """Return the cached map state for ``name`` and the current key, or a new one."""
        cache_key = (name, self._encoded_current_key)
        hit = self._internal_state_cache.get(cache_key)
        if hit is not None:
            return hit
        return self._create_internal_map_state(
            name, map_key_coder, map_value_coder)

    def _create_bag_state(self, state_spec: userstate.StateSpec) \
            -> userstate.AccumulatingRuntimeState:
        """
        Build a bag state for ``state_spec`` bound to the current encoded key.

        :raises NotImplementedError: if ``state_spec`` is not a ``BagStateSpec``.
        """
        if not isinstance(state_spec, userstate.BagStateSpec):
            raise NotImplementedError(state_spec)
        bag_user_state = beam_fn_api_pb2.StateKey.BagUserState(
            transform_id="",
            user_state_id=state_spec.name,
            key=self._encoded_current_key)
        return SynchronousBagRuntimeState(
            self._state_handler,
            state_key=beam_fn_api_pb2.StateKey(bag_user_state=bag_user_state),
            value_coder=state_spec.coder)

    def _create_internal_map_state(self, name, map_key_coder, map_value_coder):
        """Build an internal map state for ``name`` bound to the current encoded key."""
        # Currently the `beam_fn_api.proto` does not support MapState, so the
        # `MultimapSideInput` message is used to mark the state as a MapState for now.
        side_input = beam_fn_api_pb2.StateKey.MultimapSideInput(
            transform_id="",
            side_input_id=name,
            key=self._encoded_current_key)
        map_state_key = beam_fn_api_pb2.StateKey(multimap_side_input=side_input)
        return InternalSynchronousMapRuntimeState(
            self._map_state_handler, map_state_key, map_key_coder, map_value_coder,
            self._map_state_write_cache_size)

    def set_current_key(self, key):
        """
        Switch to ``key`` and re-point every registered state at it.

        Nothing happens when ``key`` equals the current key. Otherwise the
        internal state of the previous key is cached (when the cache size is
        positive) and replaced by the internal state of the new key.

        :raises Exception: if a registered state has an unknown wrapper type.
        """
        if key == self._current_key:
            return
        previous_encoded_key = self._encoded_current_key
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode_nested(key)

        bag_backed_types = (SynchronousValueRuntimeState,
                            SynchronousListRuntimeState)
        for state_name, wrapper in self._all_states.items():
            if self._state_cache_size > 0:
                # keep the internal state of the previous key for later reuse
                self._internal_state_cache.put(
                    (state_name, previous_encoded_key), wrapper._internal_state)

            if isinstance(wrapper, bag_backed_types):
                wrapper._internal_state = self._get_internal_bag_state(
                    state_name, wrapper._internal_state._value_coder)
                continue
            if isinstance(wrapper, SynchronousMapRuntimeState):
                wrapper._internal_state = self._get_internal_map_state(
                    state_name,
                    wrapper._internal_state._map_key_coder,
                    wrapper._internal_state._map_value_coder)
                continue
            raise Exception("Unknown internal state '%s': %s" %
                            (state_name, wrapper))

    def get_current_key(self):
        """Return the key most recently passed to ``set_current_key``."""
        return self._current_key

    def commit(self):
        """Commit every cached internal state and every active one that is not cached."""
        for cached_state in self._internal_state_cache:
            self.commit_internal_state(cached_state)
        for state_name, wrapper in self._all_states.items():
            if (state_name, self._encoded_current_key) in self._internal_state_cache:
                # already committed by the loop over the cache above
                continue
            self.commit_internal_state(wrapper._internal_state)

    def clear_cached_iterators(self):
        """Send the clear-iterators mark for the current key if any iterator is cached."""
        if self._map_state_handler.get_cached_iterators_num() > 0:
            mark = self._clear_iterator_mark
            mark.multimap_side_input.key = self._encoded_current_key
            self._map_state_handler.clear(mark)

    @staticmethod
    def commit_internal_state(internal_state):
        """Commit ``internal_state`` and reset bag states so they can be reused."""
        internal_state.commit()
        if not isinstance(internal_state, SynchronousBagRuntimeState):
            return
        # reset the status of the internal state to reuse the object cross bundle
        internal_state._cleared = False
        internal_state._added_elements = []
Example #2
0
class RemoteKeyedStateBackend(object):
    """
    A keyed state backend provides methods for managing keyed state.

    State wrappers are registered once per name in ``_all_states``. The internal
    states they delegate to are identified by the tuple
    ``(state name, encoded key, encoded namespace)``, which is also the key used
    for the internal LRU state cache.
    """

    def __init__(self,
                 state_handler,
                 key_coder,
                 namespace_coder,
                 state_cache_size,
                 map_state_read_cache_size,
                 map_state_write_cache_size):
        """
        :param state_handler: handler given to bag states and wrapped in a
            ``CachingMapStateHandler`` for map states.
        :param key_coder: coder whose ``_field_coders`` build the
            ``FlattenRowCoder`` used to encode the current key.
        :param namespace_coder: coder used to encode namespaces; may be falsy, in
            which case no namespace coder implementation is created.
        :param state_cache_size: capacity given to the internal state cache;
            old-key states are only cached when it is greater than zero.
        :param map_state_read_cache_size: forwarded to ``CachingMapStateHandler``.
        :param map_state_write_cache_size: forwarded to every internal map state.
        """
        self._state_handler = state_handler
        self._map_state_handler = CachingMapStateHandler(
            state_handler, map_state_read_cache_size)
        from pyflink.fn_execution.coders import FlattenRowCoder
        self._key_coder_impl = FlattenRowCoder(key_coder._field_coders).get_impl()
        self.namespace_coder = namespace_coder
        if namespace_coder:
            self._namespace_coder_impl = namespace_coder.get_impl()
        else:
            self._namespace_coder_impl = None
        self._state_cache_size = state_cache_size
        self._map_state_write_cache_size = map_state_write_cache_size
        self._all_states = {}
        self._internal_state_cache = LRUCache(self._state_cache_size, None)
        # Internal states are committed by the cache's on-evict callback.
        self._internal_state_cache.set_on_evict(
            lambda key, value: self.commit_internal_state(value))
        self._current_key = None
        self._encoded_current_key = None
        # Reusable request object; its key is refreshed in
        # clear_cached_iterators() before every use.
        self._clear_iterator_mark = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="clear_iterators",
                side_input_id="clear_iterators",
                key=self._encoded_current_key))

    def get_list_state(self, name, element_coder):
        """Return the list state registered as ``name``, creating it on first use."""
        return self._wrap_internal_bag_state(
            name, element_coder, SynchronousListRuntimeState, SynchronousListRuntimeState)

    def get_value_state(self, name, value_coder):
        """Return the value state registered as ``name``, creating it on first use."""
        return self._wrap_internal_bag_state(
            name, value_coder, SynchronousValueRuntimeState, SynchronousValueRuntimeState)

    def get_map_state(self, name, map_key_coder, map_value_coder):
        """Return the map state registered as ``name``, creating it on first use."""
        if name in self._all_states:
            self.validate_map_state(name, map_key_coder, map_value_coder)
            return self._all_states[name]
        map_state = SynchronousMapRuntimeState(name, map_key_coder, map_value_coder, self)
        self._all_states[name] = map_state
        return map_state

    def get_reducing_state(self, name, coder, reduce_function):
        """Return the reducing state registered as ``name``, creating it on first use."""
        return self._wrap_internal_bag_state(
            name, coder, SynchronousReducingRuntimeState,
            partial(SynchronousReducingRuntimeState, reduce_function=reduce_function))

    def get_aggregating_state(self, name, coder, agg_function):
        """Return the aggregating state registered as ``name``, creating it on first use."""
        return self._wrap_internal_bag_state(
            name, coder, SynchronousAggregatingRuntimeState,
            partial(SynchronousAggregatingRuntimeState, agg_function=agg_function))

    def validate_state(self, name, coder, expected_type):
        """
        Check that an already registered ``name`` has ``expected_type`` and ``coder``.

        Does nothing when ``name`` is not registered.

        :raises Exception: if the registered state is not an instance of
            ``expected_type`` or its value coder differs from ``coder``.
        """
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, expected_type):
                raise Exception("The state name '%s' is already in use and not a %s."
                                % (name, expected_type))
            if state._value_coder != coder:
                raise Exception("State name corrupted: %s" % name)

    def validate_map_state(self, name, map_key_coder, map_value_coder):
        """
        Check that an already registered ``name`` is a map state with the given coders.

        Does nothing when ``name`` is not registered.

        :raises Exception: if the name belongs to another state type or either
            coder differs from the registered one.
        """
        if name in self._all_states:
            state = self._all_states[name]
            if not isinstance(state, SynchronousMapRuntimeState):
                raise Exception("The state name '%s' is already in use and not a map state."
                                % name)
            if state._map_key_coder != map_key_coder or \
                    state._map_value_coder != map_value_coder:
                raise Exception("State name corrupted: %s" % name)

    def _wrap_internal_bag_state(self, name, element_coder, wrapper_type, wrap_method):
        """
        Return the registered wrapper for ``name`` or build one with ``wrap_method``.

        :param wrapper_type: type an already registered state must be an instance of.
        :param wrap_method: callable invoked as ``wrap_method(name, element_coder, self)``
            to build a new wrapper.
        """
        if name in self._all_states:
            self.validate_state(name, element_coder, wrapper_type)
            return self._all_states[name]
        wrapped_state = wrap_method(name, element_coder, self)
        self._all_states[name] = wrapped_state
        return wrapped_state

    def _get_internal_bag_state(self, name, namespace, element_coder):
        """Return the cached bag state for (name, current key, namespace) or a new one."""
        encoded_namespace = self._encode_namespace(namespace)
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key, encoded_namespace))
        if cached_state is not None:
            return cached_state
        # The created internal state would not be put into the internal state cache
        # at once. The internal state cache is only updated when the current key changes.
        # The reason is that the state cache size may be smaller than the count of activated
        # state (i.e. the state with current key).
        state_spec = userstate.BagStateSpec(name, element_coder)
        internal_state = self._create_bag_state(state_spec, encoded_namespace)
        return internal_state

    def _get_internal_map_state(self, name, namespace, map_key_coder, map_value_coder):
        """Return the cached map state for (name, current key, namespace) or a new one."""
        encoded_namespace = self._encode_namespace(namespace)
        cached_state = self._internal_state_cache.get(
            (name, self._encoded_current_key, encoded_namespace))
        if cached_state is not None:
            return cached_state
        internal_map_state = self._create_internal_map_state(
            name, encoded_namespace, map_key_coder, map_value_coder)
        return internal_map_state

    def _create_bag_state(self, state_spec: userstate.StateSpec, encoded_namespace) \
            -> userstate.AccumulatingRuntimeState:
        """
        Build a bag state for ``state_spec`` bound to the current key and namespace.

        :raises NotImplementedError: if ``state_spec`` is not a ``BagStateSpec``.
        """
        if isinstance(state_spec, userstate.BagStateSpec):
            bag_state = SynchronousBagRuntimeState(
                self._state_handler,
                state_key=beam_fn_api_pb2.StateKey(
                    bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                        transform_id="",
                        # the encoded namespace travels in the `window` field
                        window=encoded_namespace,
                        user_state_id=state_spec.name,
                        key=self._encoded_current_key)),
                value_coder=state_spec.coder)
            return bag_state
        else:
            raise NotImplementedError(state_spec)

    def _create_internal_map_state(self, name, encoded_namespace, map_key_coder, map_value_coder):
        """Build an internal map state for ``name`` bound to the current key and namespace."""
        # Currently the `beam_fn_api.proto` does not support MapState, so we use
        # the `MultimapSideInput` message to mark the state as a MapState for now.
        state_key = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id="",
                window=encoded_namespace,
                side_input_id=name,
                key=self._encoded_current_key))
        return InternalSynchronousMapRuntimeState(
            self._map_state_handler,
            state_key,
            map_key_coder,
            map_value_coder,
            self._map_state_write_cache_size)

    def _encode_namespace(self, namespace):
        """Encode ``namespace`` with the namespace coder; ``None`` encodes to ``b''``."""
        if namespace is not None:
            encoded_namespace = self._namespace_coder_impl.encode_nested(namespace)
        else:
            encoded_namespace = b''
        return encoded_namespace

    def cache_internal_state(self, encoded_key, internal_kv_state: SynchronousKvRuntimeState):
        """
        Put the internal state of ``internal_kv_state`` into the state cache.

        :param encoded_key: encoded key the internal state belongs to.
        :param internal_kv_state: state wrapper providing ``name``, ``namespace``
            and ``get_internal_state()``.
        """
        encoded_old_namespace = self._encode_namespace(internal_kv_state.namespace)
        self._internal_state_cache.put(
            (internal_kv_state.name, encoded_key, encoded_old_namespace),
            internal_kv_state.get_internal_state())

    def set_current_key(self, key):
        """
        Switch to ``key``, caching and detaching the internal states of the old key.

        Nothing happens when ``key`` equals the current key. Otherwise every
        registered state has its ``namespace`` and ``_internal_state`` reset to
        ``None`` after its old internal state was cached (when the cache size is
        positive).
        """
        if key == self._current_key:
            return
        encoded_old_key = self._encoded_current_key
        for state_name, state_obj in self._all_states.items():
            if self._state_cache_size > 0:
                # cache old internal state
                self.cache_internal_state(encoded_old_key, state_obj)
            state_obj.namespace = None
            state_obj._internal_state = None
        # The key is switched only after the old states were cached: the
        # _get_internal_*_state helpers read self._encoded_current_key, so an
        # internal state created during caching must still see the old key.
        # NOTE(review): this matters only if get_internal_state() creates the
        # internal state lazily through this backend - confirm against the wrappers.
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode_nested(self._current_key)

    def get_current_key(self):
        """Return the key most recently passed to ``set_current_key``."""
        return self._current_key

    def commit(self):
        """Commit every cached internal state and every active one that is not cached."""
        for internal_state in self._internal_state_cache:
            self.commit_internal_state(internal_state)
        for name, state in self._all_states.items():
            # Cache keys are (name, encoded key, encoded namespace); the lookup
            # has to use the same shape, otherwise it never matches and cached
            # active states are committed a second time.
            cache_key = (name,
                         self._encoded_current_key,
                         self._encode_namespace(state.namespace))
            if cache_key not in self._internal_state_cache:
                self.commit_internal_state(state._internal_state)

    def clear_cached_iterators(self):
        """Send the clear-iterators mark for the current key if any iterator is cached."""
        if self._map_state_handler.get_cached_iterators_num() > 0:
            self._clear_iterator_mark.multimap_side_input.key = self._encoded_current_key
            self._map_state_handler.clear(self._clear_iterator_mark)

    @staticmethod
    def commit_internal_state(internal_state):
        """Commit ``internal_state`` if it is not ``None`` and reset bag states for reuse."""
        if internal_state is not None:
            internal_state.commit()
        # reset the status of the internal state to reuse the object cross bundle
        if isinstance(internal_state, SynchronousBagRuntimeState):
            internal_state._cleared = False
            internal_state._added_elements = []
Example #3
0
class RemoteKeyedStateBackend(object):
    """
    A keyed state backend provides methods for managing keyed state.

    Value states are registered once per name in ``_all_states``; the internal
    bag states they delegate to are kept in an LRU cache keyed by
    ``(state name, current key)`` whose on-evict callback commits them.
    """

    def __init__(self, state_handler, key_coder, state_cache_size):
        """
        :param state_handler: handler given to every created bag state.
        :param key_coder: coder used to encode the current key.
        :param state_cache_size: capacity given to the internal state cache.
        """
        self._state_handler = state_handler

        # The flattened row coder is only used when the optional fast coder
        # implementation can be imported. Any ordinary failure while importing
        # it falls back to the coder's own implementation, but
        # KeyboardInterrupt / SystemExit are no longer swallowed.
        try:
            from pyflink.fn_execution import coder_impl_fast
            is_fast = bool(coder_impl_fast)
        except Exception:
            is_fast = False
        if not is_fast:
            self._key_coder_impl = key_coder.get_impl()
        else:
            from pyflink.fn_execution.coders import FlattenRowCoder
            self._key_coder_impl = FlattenRowCoder(key_coder._field_coders).get_impl()
        self._state_cache_size = state_cache_size
        self._all_states = {}
        self._all_internal_states = LRUCache(self._state_cache_size, None)
        self._all_internal_states.set_on_evict(lambda k, v: v.commit())
        self._current_key = None
        self._encoded_current_key = None

    def get_value_state(self, name, value_coder):
        """Return the value state registered as ``name``, creating it on first use."""
        if name in self._all_states:
            self.validate_state(name, value_coder)
            return self._all_states[name]
        internal_bag_state = self._get_internal_bag_state(name, value_coder)
        value_state = SynchronousValueRuntimeState(internal_bag_state)
        self._all_states[name] = value_state
        return value_state

    def validate_state(self, name, coder):
        """
        Check that an already registered ``name`` uses ``coder``.

        Does nothing when ``name`` is not registered.

        :raises ValueError: if the registered value coder differs from ``coder``.
        """
        if name in self._all_states:
            state = self._all_states[name]
            if state._internal_state._value_coder != coder:
                raise ValueError("State name corrupted: %s" % name)

    def _get_internal_bag_state(self, name, element_coder):
        """Return the cached bag state for ``(name, current key)``, creating and caching it if absent."""
        cached_state = self._all_internal_states.get((name, self._current_key))
        if cached_state is not None:
            return cached_state
        state_spec = userstate.BagStateSpec(name, element_coder)
        internal_state = self._create_state(state_spec)
        self._all_internal_states.put((name, self._current_key), internal_state)
        return internal_state

    def _create_state(self, state_spec: userstate.StateSpec) -> userstate.AccumulatingRuntimeState:
        """
        Build a bag state for ``state_spec`` bound to the current encoded key.

        :raises NotImplementedError: if ``state_spec`` is not a ``BagStateSpec``.
        """
        if isinstance(state_spec, userstate.BagStateSpec):
            bag_state = SynchronousBagRuntimeState(
                self._state_handler,
                state_key=beam_fn_api_pb2.StateKey(
                    bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                        transform_id="",
                        user_state_id=state_spec.name,
                        key=self._encoded_current_key)),
                value_coder=state_spec.coder)
            return bag_state
        else:
            raise NotImplementedError(state_spec)

    def set_current_key(self, key):
        """Switch to ``key`` and re-point every registered state at its internal state."""
        self._current_key = key
        self._encoded_current_key = self._key_coder_impl.encode_nested(self._current_key)
        for state_name, state_obj in self._all_states.items():
            state_obj._internal_state = \
                self._get_internal_bag_state(state_name, state_obj._internal_state._value_coder)

    def get_current_key(self):
        """Return the key most recently passed to ``set_current_key``."""
        return self._current_key

    def commit(self):
        """
        Evict all cached internal states and forget all registered states.

        Presumably ``evict_all`` runs the on-evict callback, which commits each
        evicted state - confirm against ``LRUCache``.
        """
        self._all_internal_states.evict_all()
        self._all_states = {}

    def reset(self):
        """Evict all cached internal states and forget all registered states, exactly as ``commit`` does."""
        self._all_internal_states.evict_all()
        self._all_states = {}