def test_extend_fetches_initial_state(self):
    """Check that extending cached state first loads the existing value.

    The underlying handler is seeded with 42 and the first operation in the
    bundle is an append of 43 with no prior read; a subsequent read must
    return both values. The test then verifies that clearing empties the
    state and that later appends start from the empty state.
    """
    coder = VarIntCoder()
    coder_impl = coder.get_impl()

    class UnderlyingStateHandler(object):
        """In-memory fake that keeps the state as one encoded byte string."""
        def set_value(self, value):
            # Replace the stored bytes with the encoding of a single value.
            self._encoded_values = coder.encode(value)

        def get_raw(self, *args):
            # Second tuple element is always None here, i.e. this fake never
            # reports anything in the second slot of the result.
            return self._encoded_values, None

        def append_raw(self, _key, data):
            # `data` (not `bytes`) so the builtin is not shadowed.
            self._encoded_values += data

        def clear(self, *args):
            self._encoded_values = bytes()

        @contextlib.contextmanager
        def process_instruction_id(self, bundle_id):
            yield

    underlying_state_handler = UnderlyingStateHandler()
    state_cache = statecache.StateCache(100)
    handler = sdk_worker.CachingStateHandler(
        state_cache, underlying_state_handler)

    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))

    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
        UserState())

    def get():
        return list(handler.blocking_get(state, coder_impl, True))

    def append(value):
        handler.extend(state, coder_impl, [value], True)

    def clear():
        handler.clear(state, True)

    # Initialize state
    underlying_state_handler.set_value(42)
    with handler.process_instruction_id('bundle', [cache_token]):
        # Append without reading beforehand
        append(43)
        self.assertEqual(get(), [42, 43])
        clear()
        self.assertEqual(get(), [])
        append(44)
        self.assertEqual(get(), [44])
def test_continuation_token(self):
    """Exercise reads, appends and clears when the backend uses continuations.

    While the pre-existing values are still in play, reads are expected to
    come back as ``CachingStateHandler.ContinuationIterable``; after a clear,
    reads are expected to come back as a plain ``list``.
    """
    backend = self.UnderlyingStateHandler()
    handler = sdk_worker.CachingStateHandler(
        statecache.StateCache(100), backend)
    coder = VarIntCoder()

    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))
    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
        UserState())

    def read(materialize=True):
        # Return the raw result object unless a concrete list is requested.
        fetched = handler.blocking_get(state, coder.get_impl())
        if materialize:
            return list(fetched)
        return fetched

    def read_type():
        return type(read(materialize=False))

    def add(*values):
        handler.extend(state, coder.get_impl(), values)

    def wipe():
        handler.clear(state)

    # NOTE(review): presumably makes the fake hand results back in pages;
    # confirm against the shared UnderlyingStateHandler fixture.
    backend.set_continuations(True)
    backend.set_values([45, 46, 47], coder)

    with handler.process_instruction_id('bundle', [cache_token]):
        self.assertEqual(
            read_type(), CachingStateHandler.ContinuationIterable)
        self.assertEqual(read(), [45, 46, 47])

        add(48, 49)
        self.assertEqual(
            read_type(), CachingStateHandler.ContinuationIterable)
        self.assertEqual(read(), [45, 46, 47, 48, 49])

        wipe()
        self.assertEqual(read_type(), list)
        self.assertEqual(read(), [])

        add(1)
        self.assertEqual(read(), [1])
        add(2, 3)
        self.assertEqual(read(), [1, 2, 3])

        wipe()
        # One extend call per element, 1000 calls in total.
        for n in range(1000):
            add(n)
        self.assertEqual(read_type(), list)
        self.assertEqual(read(), list(range(1000)))
def test_append_clear_with_preexisting_state(self):
    """Append to and clear a bag whose backend already holds a value.

    The backend is seeded with 42 before the bundle starts. The first
    operation inside the bundle is an append, so the following read must
    include the seeded value as well. Later steps check clearing and
    appending from a list, a tuple and a ``range``.
    """
    coder = VarIntCoder()
    backend = self.UnderlyingStateHandler()
    handler = sdk_worker.CachingStateHandler(
        statecache.StateCache(100), backend)

    state = beam_fn_api_pb2.StateKey(
        bag_user_state=beam_fn_api_pb2.StateKey.BagUserState(
            user_state_id='state1'))
    cache_token = beam_fn_api_pb2.ProcessBundleRequest.CacheToken(
        token=b'state_token1',
        user_state=beam_fn_api_pb2.ProcessBundleRequest.CacheToken.
        UserState())

    def read():
        # The result is compared as returned, without wrapping it in list().
        return handler.blocking_get(state, coder.get_impl())

    def add(values):
        handler.extend(state, coder.get_impl(), values)

    def wipe():
        handler.clear(state)

    # Seed the backend before any bundle is active.
    backend.set_value(42, coder)

    with handler.process_instruction_id('bundle', [cache_token]):
        # Append without a preceding read.
        add([43])
        self.assertEqual(read(), [42, 43])

        wipe()
        self.assertEqual(read(), [])

        add([44, 45])
        self.assertEqual(read(), [44, 45])
        add((46, 47))
        self.assertEqual(read(), [44, 45, 46, 47])

        wipe()
        add(range(1000))
        self.assertEqual(read(), list(range(1000)))
def __init__(self,
             unused_payload,  # type: None
             state,  # type: sdk_worker.StateHandler
             provision_info,  # type: Optional[ExtendedProvisionInfo]
             worker_manager,  # type: WorkerHandlerManager
            ):
    # type: (...) -> None
    """Wire up an SDK worker that runs inside the current process.

    The handler passes itself as the first argument of the base initializer
    and also stores itself as ``control_conn``. A ``StateCache`` of
    ``STATE_CACHE_SIZE`` backs both the caching state handler given to the
    bundle processor cache and the worker's cache-metrics callback.

    Args:
      unused_payload: Not read by this initializer.
      state: State handler wrapped by ``sdk_worker.CachingStateHandler``
        and forwarded to the base initializer.
      provision_info: Forwarded unchanged to the base initializer.
      worker_manager: Supplies ``_process_bundle_descriptors`` for the
        bundle processor cache.
    """
    super(EmbeddedWorkerHandler, self).__init__(
        self, data_plane.InMemoryDataChannel(), state, provision_info)
    self.control_conn = self  # type: ignore  # need Protocol to describe this
    # NOTE(review): data_plane_handler is presumably set by the base
    # initializer from the in-memory channel passed above -- confirm.
    self.data_conn = self.data_plane_handler

    cache = StateCache(STATE_CACHE_SIZE)
    state_handler_factory = SingletonStateHandlerFactory(
        sdk_worker.CachingStateHandler(cache, state))
    data_channel_factory = data_plane.InMemoryDataChannelFactory(
        self.data_plane_handler.inverse())
    self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
        state_handler_factory,
        data_channel_factory,
        worker_manager._process_bundle_descriptors)

    self.worker = sdk_worker.SdkWorker(
        self.bundle_processor_cache,
        state_cache_metrics_fn=cache.get_monitoring_infos)
    self._uid_counter = 0
def test_caching(self):
    """Verify which reads hit the cache across bundles and cache tokens.

    The fake backend returns a fresh, increasing counter on every raw read,
    so a repeated value proves the read was served from cache and a new
    value proves it reached the backend. Five bundles are run with
    different combinations of user-state and side-input cache tokens.
    """
    coder = VarIntCoder()
    coder_impl = coder.get_impl()

    class FakeUnderlyingState(object):
        """Backend stub whose every raw read yields the next counter value."""
        def set_counter(self, n):
            self._counter = n

        def get_raw(self, *args):
            self._counter += 1
            return coder.encode(self._counter), None

        @contextlib.contextmanager
        def process_instruction_id(self, bundle_id):
            yield

    underlying_state = FakeUnderlyingState()
    caching_state_handler = sdk_worker.CachingStateHandler(
        statecache.StateCache(100), underlying_state)

    # Local aliases keep the proto constructions below readable.
    StateKey = beam_fn_api_pb2.StateKey
    CacheToken = beam_fn_api_pb2.ProcessBundleRequest.CacheToken

    state1 = StateKey(
        bag_user_state=StateKey.BagUserState(user_state_id='state1'))
    state2 = StateKey(
        bag_user_state=StateKey.BagUserState(user_state_id='state2'))
    side1 = StateKey(
        multimap_side_input=StateKey.MultimapSideInput(
            transform_id='transform', side_input_id='side1'))
    side2 = StateKey(
        iterable_side_input=StateKey.IterableSideInput(
            transform_id='transform', side_input_id='side2'))

    state_token1 = CacheToken(
        token=b'state_token1', user_state=CacheToken.UserState())
    state_token2 = CacheToken(
        token=b'state_token2', user_state=CacheToken.UserState())
    side1_token1 = CacheToken(
        token=b'side1_token1',
        side_input=CacheToken.SideInput(
            transform_id='transform', side_input_id='side1'))
    side1_token2 = CacheToken(
        token=b'side1_token2',
        side_input=CacheToken.SideInput(
            transform_id='transform', side_input_id='side1'))

    def get_as_list(key):
        return list(caching_state_handler.blocking_get(key, coder_impl))

    def expect(key, expected):
        self.assertEqual(get_as_list(key), expected)

    # Bundle 1: no cache tokens at all.
    underlying_state.set_counter(100)
    with caching_state_handler.process_instruction_id('bundle1', []):
        expect(state1, [101])  # uncached
        expect(state2, [102])  # uncached
        expect(state1, [101])  # cached on bundle
        expect(side1, [103])  # uncached
        expect(side2, [104])  # uncached

    # Bundle 2: first use of state_token1 and side1_token1.
    underlying_state.set_counter(200)
    with caching_state_handler.process_instruction_id(
            'bundle2', [state_token1, side1_token1]):
        expect(state1, [201])  # uncached
        expect(state2, [202])  # uncached
        expect(state1, [201])  # cached on state token1
        expect(side1, [203])  # uncached
        expect(side1, [203])  # cached on side1_token1
        expect(side2, [204])  # uncached
        expect(side2, [204])  # cached on bundle

    # Bundle 3: same tokens as bundle 2.
    underlying_state.set_counter(300)
    with caching_state_handler.process_instruction_id(
            'bundle3', [state_token1, side1_token1]):
        expect(state1, [201])  # cached on state token1
        expect(state2, [202])  # cached on state token1
        expect(state1, [201])  # cached on state token1
        expect(side1, [203])  # cached on side1_token1
        expect(side1, [203])  # cached on side1_token1
        expect(side2, [301])  # uncached
        expect(side2, [301])  # cached on bundle

    # Bundle 4: user-state token changes, side-input token stays.
    underlying_state.set_counter(400)
    with caching_state_handler.process_instruction_id(
            'bundle4', [state_token2, side1_token1]):
        expect(state1, [401])  # uncached
        expect(state2, [402])  # uncached
        expect(state1, [401])  # cached on state token2
        expect(side1, [203])  # cached on side1_token1
        expect(side1, [203])  # cached on side1_token1
        expect(side2, [403])  # uncached
        expect(side2, [403])  # cached on bundle

    # Bundle 5: side-input token changes, user-state token stays.
    underlying_state.set_counter(500)
    with caching_state_handler.process_instruction_id(
            'bundle5', [state_token2, side1_token2]):
        expect(state1, [401])  # cached on state token2
        expect(state2, [402])  # cached on state token2
        expect(state1, [401])  # cached on state token2
        expect(side1, [501])  # uncached
        expect(side1, [501])  # cached on side1_token2
        expect(side2, [502])  # uncached
        expect(side2, [502])  # cached on bundle