Exemplo n.º 1
0
    def test_extend_fetches_initial_state(self):
        """extend() on unread bag state must keep the backend's initial value.

        The backend is seeded with 42 before the bundle starts, and the
        first operation is an append without a preceding read; the next
        read must return the seeded value followed by the appended one.
        Afterwards, clear() followed by an append must not bring the seeded
        value back.
        """
        coder = VarIntCoder()
        coder_impl = coder.get_impl()

        class UnderlyingStateHandler(object):
            """Fake backend storing one bag as a buffer of encoded bytes.

            ``set_value`` seeds the buffer with a single encoded element,
            ``append_raw`` concatenates encoded bytes onto it, ``clear``
            empties it, and ``get_raw`` returns the whole buffer with no
            continuation token.
            """
            def set_value(self, value):
                self._encoded_values = coder.encode(value)

            def get_raw(self, *args):
                # Everything fits in one response: continuation token is None.
                return self._encoded_values, None

            def append_raw(self, _key, data):
                # Named ``data`` rather than ``bytes`` so the builtin type is
                # not shadowed.
                self._encoded_values += data

            def clear(self, *args):
                self._encoded_values = b''

            @contextlib.contextmanager
            def process_instruction_id(self, bundle_id):
                yield

        underlying_state_handler = UnderlyingStateHandler()
        state_cache = statecache.StateCache(100)
        handler = sdk_worker.CachingStateHandler(state_cache,
                                                 underlying_state_handler)

        state = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        # NOTE(review): the trailing positional ``True`` in the handler calls
        # below presumably marks the state as cacheable; confirm against
        # CachingStateHandler's signature.
        def get():
            return list(handler.blocking_get(state, coder_impl, True))

        def append(value):
            handler.extend(state, coder_impl, [value], True)

        def clear():
            handler.clear(state, True)

        # Initialize state
        underlying_state_handler.set_value(42)
        with handler.process_instruction_id('bundle', [cache_token]):
            # Append without reading beforehand
            append(43)
            self.assertEqual(get(), [42, 43])
            clear()
            self.assertEqual(get(), [])
            append(44)
            self.assertEqual(get(), [44])
Exemplo n.º 2
0
    def test_continuation_token(self):
        """Reads stay lazy while the backend hands out continuation tokens.

        The underlying fake is switched into continuation mode via
        ``set_continuations(True)`` and seeded with [45, 46, 47]. Reads then
        return a ``CachingStateHandler.ContinuationIterable``, also after
        appending. After ``clear()``, reads return a plain ``list``, and they
        still do so after many individual appends.
        """
        underlying_state_handler = self.UnderlyingStateHandler()
        state_cache = statecache.StateCache(100)
        handler = sdk_worker.CachingStateHandler(state_cache,
                                                 underlying_state_handler)

        coder = VarIntCoder()
        # One coder impl shared by every read and write, as in the sibling
        # tests.
        coder_impl = coder.get_impl()

        state = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        def get(materialize=True):
            result = handler.blocking_get(state, coder_impl)
            return list(result) if materialize else result

        def get_type():
            # Exact type check: the test distinguishes the lazy iterable
            # from a materialized list.
            return type(get(materialize=False))

        def append(*values):
            handler.extend(state, coder_impl, values)

        def clear():
            handler.clear(state)

        underlying_state_handler.set_continuations(True)
        underlying_state_handler.set_values([45, 46, 47], coder)
        with handler.process_instruction_id('bundle', [cache_token]):
            self.assertEqual(get_type(),
                             CachingStateHandler.ContinuationIterable)
            self.assertEqual(get(), [45, 46, 47])
            append(48, 49)
            self.assertEqual(get_type(),
                             CachingStateHandler.ContinuationIterable)
            self.assertEqual(get(), [45, 46, 47, 48, 49])
            clear()
            self.assertEqual(get_type(), list)
            self.assertEqual(get(), [])
            append(1)
            self.assertEqual(get(), [1])
            append(2, 3)
            self.assertEqual(get(), [1, 2, 3])
            clear()
            # Many single-element appends must still yield a plain list.
            for value in range(1000):
                append(value)
            self.assertEqual(get_type(), list)
            self.assertEqual(get(), list(range(1000)))
Exemplo n.º 3
0
    def test_append_clear_with_preexisting_state(self):
        """Append and clear a bag whose backend already holds a value.

        The backend is seeded with 42 before the bundle. The first append
        happens without a read, so the handler must combine the seeded value
        with the new element. Later steps clear the bag and append lists,
        tuples and a range, checking the bag contents after each step.
        """
        state_key = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))

        user_state_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())

        varint = VarIntCoder()

        backend = self.UnderlyingStateHandler()
        caching_handler = sdk_worker.CachingStateHandler(
            statecache.StateCache(100), backend)

        def read_bag():
            return caching_handler.blocking_get(state_key, varint.get_impl())

        def extend_bag(elements):
            caching_handler.extend(state_key, varint.get_impl(), elements)

        def clear_bag():
            caching_handler.clear(state_key)

        # Seed the backend before any bundle runs.
        backend.set_value(42, varint)
        with caching_handler.process_instruction_id('bundle',
                                                    [user_state_token]):
            # The very first write is not preceded by a read.
            extend_bag([43])
            self.assertEqual(read_bag(), [42, 43])

            clear_bag()
            self.assertEqual(read_bag(), [])

            extend_bag([44, 45])
            self.assertEqual(read_bag(), [44, 45])

            # Any iterable is accepted, not only lists.
            extend_bag((46, 47))
            self.assertEqual(read_bag(), [44, 45, 46, 47])

            clear_bag()
            extend_bag(range(1000))
            self.assertEqual(read_bag(), list(range(1000)))
Exemplo n.º 4
0
 def __init__(self,
              unused_payload,  # type: None
              state,  # type: sdk_worker.StateHandler
              provision_info,  # type: Optional[ExtendedProvisionInfo]
              worker_manager,  # type: WorkerHandlerManager
             ):
   # type: (...) -> None
   """Build an in-process worker wired to local control and data planes.

   The handler passes itself as the control handler to the base class and
   uses an in-memory data channel. State access goes through a
   CachingStateHandler wrapping ``state`` and a StateCache sized by
   STATE_CACHE_SIZE. The cache's ``get_monitoring_infos`` is handed to the
   SdkWorker as ``state_cache_metrics_fn``.
   """
   in_memory_channel = data_plane.InMemoryDataChannel()
   super(EmbeddedWorkerHandler, self).__init__(
       self, in_memory_channel, state, provision_info)
   # The embedded handler doubles as its own control connection.
   self.control_conn = self  # type: ignore  # need Protocol to describe this
   self.data_conn = self.data_plane_handler

   state_cache = StateCache(STATE_CACHE_SIZE)
   caching_state = sdk_worker.CachingStateHandler(state_cache, state)
   state_handler_factory = SingletonStateHandlerFactory(caching_state)
   # The worker side reads and writes through the inverse of our channel.
   data_channel_factory = data_plane.InMemoryDataChannelFactory(
       self.data_plane_handler.inverse())
   self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
       state_handler_factory,
       data_channel_factory,
       worker_manager._process_bundle_descriptors)

   self.worker = sdk_worker.SdkWorker(
       self.bundle_processor_cache,
       state_cache_metrics_fn=state_cache.get_monitoring_infos)
   self._uid_counter = 0
Exemplo n.º 5
0
    def test_caching(self):
        """Reads are cached within a bundle and across bundles via tokens.

        Every ``get_raw`` on the fake backend returns a freshly incremented
        counter, so getting the same value twice proves the second read was
        served from the cache. Across bundles:
        - user state is reused while the same user-state token is supplied;
        - side input ``side1`` is reused while its side-input token matches;
        - anything without a matching token (all of bundle1, and ``side2``
          throughout) is only reused within the current bundle.
        """
        coder = VarIntCoder()
        coder_impl = coder.get_impl()

        class FakeUnderlyingState(object):
            """Simply returns an incremented counter as the state "value."
      """
            def set_counter(self, n):
                self._counter = n

            def get_raw(self, *args):
                # Each uncached read observes a new value.
                self._counter += 1
                return coder.encode(self._counter), None

            @contextlib.contextmanager
            def process_instruction_id(self, bundle_id):
                yield

        underlying_state = FakeUnderlyingState()
        state_cache = statecache.StateCache(100)
        caching_state_handler = sdk_worker.CachingStateHandler(
            state_cache, underlying_state)

        state1 = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state1'))
        state2 = beam_fn_api_pb2.StateKey(
            bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
                user_state_id='state2'))
        side1 = beam_fn_api_pb2.StateKey(
            multimap_side_input=beam_fn_api_pb2.StateKey.MultimapSideInput(
                transform_id='transform', side_input_id='side1'))
        side2 = beam_fn_api_pb2.StateKey(
            iterable_side_input=beam_fn_api_pb2.StateKey.IterableSideInput(
                transform_id='transform', side_input_id='side2'))

        state_token1 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token1',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())
        state_token2 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'state_token2',
            user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            UserState())
        side1_token1 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'side1_token1',
            side_input=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            SideInput(transform_id='transform', side_input_id='side1'))
        side1_token2 = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
            token=b'side1_token2',
            side_input=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
            SideInput(transform_id='transform', side_input_id='side1'))

        def get_as_list(key):
            return list(caching_state_handler.blocking_get(key, coder_impl))

        # Bundle 1: no cache tokens, so only in-bundle reuse.
        underlying_state.set_counter(100)
        with caching_state_handler.process_instruction_id('bundle1', []):
            self.assertEqual(get_as_list(state1), [101])  # uncached
            self.assertEqual(get_as_list(state2), [102])  # uncached
            self.assertEqual(get_as_list(state1), [101])  # cached on bundle
            self.assertEqual(get_as_list(side1), [103])  # uncached
            self.assertEqual(get_as_list(side2), [104])  # uncached

        # Bundle 2: first bundle with tokens; everything is fetched once.
        underlying_state.set_counter(200)
        with caching_state_handler.process_instruction_id(
                'bundle2', [state_token1, side1_token1]):
            self.assertEqual(get_as_list(state1), [201])  # uncached
            self.assertEqual(get_as_list(state2), [202])  # uncached
            self.assertEqual(get_as_list(state1),
                             [201])  # cached on state token1
            self.assertEqual(get_as_list(side1), [203])  # uncached
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side2), [204])  # uncached
            self.assertEqual(get_as_list(side2), [204])  # cached on bundle

        # Bundle 3: same tokens, so bundle 2's values are reused.
        underlying_state.set_counter(300)
        with caching_state_handler.process_instruction_id(
                'bundle3', [state_token1, side1_token1]):
            self.assertEqual(get_as_list(state1),
                             [201])  # cached on state token1
            self.assertEqual(get_as_list(state2),
                             [202])  # cached on state token1
            self.assertEqual(get_as_list(state1),
                             [201])  # cached on state token1
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side2), [301])  # uncached
            self.assertEqual(get_as_list(side2), [301])  # cached on bundle

        # Bundle 4: new user-state token invalidates user state only.
        underlying_state.set_counter(400)
        with caching_state_handler.process_instruction_id(
                'bundle4', [state_token2, side1_token1]):
            self.assertEqual(get_as_list(state1), [401])  # uncached
            self.assertEqual(get_as_list(state2), [402])  # uncached
            self.assertEqual(get_as_list(state1),
                             [401])  # cached on state token2
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side1),
                             [203])  # cached on side1_token1
            self.assertEqual(get_as_list(side2), [403])  # uncached
            self.assertEqual(get_as_list(side2), [403])  # cached on bundle

        # Bundle 5: new side1 token invalidates side1 only.
        underlying_state.set_counter(500)
        with caching_state_handler.process_instruction_id(
                'bundle5', [state_token2, side1_token2]):
            self.assertEqual(get_as_list(state1),
                             [401])  # cached on state token2
            self.assertEqual(get_as_list(state2),
                             [402])  # cached on state token2
            self.assertEqual(get_as_list(state1),
                             [401])  # cached on state token2
            self.assertEqual(get_as_list(side1), [501])  # uncached
            self.assertEqual(get_as_list(side1),
                             [501])  # cached on side1_token2
            self.assertEqual(get_as_list(side2), [502])  # uncached
            self.assertEqual(get_as_list(side2), [502])  # cached on bundle