def test_evict_all(self):
  cache = StateCache(5)
  cache.put("key", "cache_token", "value")
  cache.put("key2", "cache_token", "value2")
  self.assertEqual(len(cache), 2)
  cache.evict_all()
  self.assertEqual(len(cache), 0)
  self.assertEqual(cache.get("key", "cache_token"), None)
  self.assertEqual(cache.get("key2", "cache_token"), None)
def test_max_size(self):
  cache = StateCache(2)
  cache.put("key", "cache_token", "value")
  cache.put("key2", "cache_token", "value")
  self.assertEqual(len(cache), 2)
  cache.put("key2", "cache_token", "value")
  self.assertEqual(len(cache), 2)
  cache.put("key", "cache_token", "value")
  self.assertEqual(len(cache), 2)
class GrpcStateHandlerFactory(StateHandlerFactory):
  """A factory for ``GrpcStateHandler``.

  Caches the created channels by ``state descriptor url``.
  """

  def __init__(self, state_cache_size, credentials=None):
    self._state_handler_cache = {}
    self._lock = threading.Lock()
    self._throwing_state_handler = ThrowingStateHandler()
    self._credentials = credentials
    self._state_cache = StateCache(state_cache_size)

  def create_state_handler(self, api_service_descriptor):
    if not api_service_descriptor:
      return self._throwing_state_handler
    url = api_service_descriptor.url
    if url not in self._state_handler_cache:
      with self._lock:
        if url not in self._state_handler_cache:
          # Options to have no limits (-1) on the size of the messages
          # received or sent over the data plane. The actual buffer size is
          # controlled in a layer above.
          options = [('grpc.max_receive_message_length', -1),
                     ('grpc.max_send_message_length', -1)]
          if self._credentials is None:
            logging.info('Creating insecure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.insecure_channel(
                url, options=options)
          else:
            logging.info('Creating secure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.secure_channel(
                url, self._credentials, options=options)
          logging.info('State channel established.')
          # Add workerId to the grpc channel.
          grpc_channel = grpc.intercept_channel(grpc_channel,
                                                WorkerIdInterceptor())
          self._state_handler_cache[url] = CachingMaterializingStateHandler(
              self._state_cache,
              GrpcStateHandler(
                  beam_fn_api_pb2_grpc.BeamFnStateStub(grpc_channel)))
    return self._state_handler_cache[url]

  def close(self):
    logging.info('Closing all cached gRPC state handlers.')
    for _, state_handler in self._state_handler_cache.items():
      state_handler.done()
    self._state_handler_cache.clear()
    self._state_cache.evict_all()
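# create_state_handler above uses the check-lock-recheck idiom: the common
# case (a cache hit) takes no lock, while concurrent first requests for the
# same URL still create only one channel. A minimal, self-contained sketch of
# that idiom follows; ChannelCache and make_channel are hypothetical names for
# illustration, not part of the code above.
import threading


class ChannelCache(object):
  """Sketch of the check-lock-recheck caching idiom (hypothetical class)."""

  def __init__(self, make_channel):
    self._make_channel = make_channel  # assumed callable: url -> channel
    self._cache = {}
    self._lock = threading.Lock()

  def get(self, url):
    # Lock-free fast path: existing entries pay no synchronization cost.
    if url not in self._cache:
      with self._lock:
        # Re-check under the lock: another thread may have populated the
        # entry between our first check and acquiring the lock.
        if url not in self._cache:
          self._cache[url] = self._make_channel(url)
    return self._cache[url]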
def test_overwrite(self):
  cache = StateCache(2)
  cache.put("key", "cache_token", "value")
  cache.put("key", "cache_token2", "value2")
  self.assertEqual(len(cache), 1)
  self.assertEqual(cache.get("key", "cache_token"), None)
  self.assertEqual(cache.get("key", "cache_token2"), "value2")
def __init__(self,
             control_address,
             worker_count,
             credentials=None,
             worker_id=None,
             # Caching is disabled by default
             state_cache_size=0,
             profiler_factory=None):
  self._alive = True
  self._worker_count = worker_count
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    logging.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    logging.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  logging.info('Control channel established.')

  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  # Workers for process/finalize bundle.
  self.workers = queue.Queue()
  # One worker for progress/split requests.
  self.progress_worker = SdkWorker(
      self._bundle_processor_cache, profiler_factory=self._profiler_factory)
  # One thread is enough for getting the progress report.
  # Assumption:
  # Progress report generation should not do IO or wait on other resources.
  # Without wait, having multiple threads will not improve performance and
  # will only add complexity.
  self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
  # Finalize and process share one thread pool.
  self._process_thread_pool = futures.ThreadPoolExecutor(
      max_workers=self._worker_count)
  self._responses = queue.Queue()
  self._process_bundle_queue = queue.Queue()
  self._unscheduled_process_bundle = {}
  logging.info('Initializing SDKHarness with %s workers.', self._worker_count)
def __init__(self,
             control_address,  # type: str
             credentials=None,
             worker_id=None,  # type: Optional[str]
             # Caching is disabled by default
             state_cache_size=0,
             # time-based data buffering is disabled by default
             data_buffer_time_limit_ms=0,
             profiler_factory=None,  # type: Optional[Callable[..., Profile]]
             status_address=None,  # type: Optional[str]
             ):
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')
  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id, data_buffer_time_limit_ms)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  if status_address:
    try:
      self._status_handler = FnApiWorkerStatusHandler(
          status_address, self._bundle_processor_cache)
    except Exception:
      traceback_string = traceback.format_exc()
      _LOGGER.warning(
          'Error creating worker status request handler, '
          'skipping status report. Traceback: %s', traceback_string)
      # Ensure the attribute is set even when handler creation fails.
      self._status_handler = None
  else:
    self._status_handler = None

  # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
  # progress once dataflow runner's excessive progress polling is removed.
  self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
  self._worker_thread_pool = UnboundedThreadPoolExecutor()
  self._responses = queue.Queue(
  )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
def test_put_get(self):
  cache = StateCache(5)
  cache.put("key", "cache_token", "value")
  self.assertEqual(len(cache), 1)
  self.assertEqual(cache.get("key", "cache_token"), "value")
  self.assertEqual(cache.get("key", "cache_token2"), None)
  with self.assertRaises(Exception):
    self.assertEqual(cache.get("key", None), None)
def __init__(self,
             control_address,
             credentials=None,
             worker_id=None,
             # Caching is disabled by default
             state_cache_size=0,
             profiler_factory=None):
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')
  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
  # progress once dataflow runner's excessive progress polling is removed.
  self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
  self._worker_thread_pool = UnboundedThreadPoolExecutor()
  self._responses = queue.Queue()
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
def __init__(self,
             unused_payload,  # type: None
             state,  # type: sdk_worker.StateHandler
             provision_info,  # type: Optional[ExtendedProvisionInfo]
             worker_manager,  # type: WorkerHandlerManager
             ):
  # type: (...) -> None
  super(EmbeddedWorkerHandler, self).__init__(
      self, data_plane.InMemoryDataChannel(), state, provision_info)
  self.control_conn = self  # type: ignore  # need Protocol to describe this
  self.data_conn = self.data_plane_handler
  state_cache = StateCache(STATE_CACHE_SIZE)
  self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
      SingletonStateHandlerFactory(
          sdk_worker.CachingStateHandler(state_cache, state)),
      data_plane.InMemoryDataChannelFactory(
          self.data_plane_handler.inverse()),
      worker_manager._process_bundle_descriptors)
  self.worker = sdk_worker.SdkWorker(
      self.bundle_processor_cache,
      state_cache_metrics_fn=state_cache.get_monitoring_infos)
  self._uid_counter = 0
def __init__(self,
             control_address,  # type: str
             credentials=None,
             worker_id=None,  # type: Optional[str]
             # Caching is disabled by default
             state_cache_size=0,
             profiler_factory=None  # type: Optional[Callable[..., Profile]]
             ):
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')
  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  self._worker_thread_pool = UnboundedThreadPoolExecutor()
  self._responses = queue.Queue(
  )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
def __init__(
    self,
    control_address,  # type: str
    credentials=None,  # type: Optional[grpc.ChannelCredentials]
    worker_id=None,  # type: Optional[str]
    # Caching is disabled by default
    state_cache_size=0,  # type: int
    # time-based data buffering is disabled by default
    data_buffer_time_limit_ms=0,  # type: int
    profiler_factory=None,  # type: Optional[Callable[..., Profile]]
    status_address=None,  # type: Optional[str]
    # Heap dump through status api is disabled by default
    enable_heap_dump=False,  # type: bool
):
  # type: (...) -> None
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  options = [('grpc.max_receive_message_length', -1),
             ('grpc.max_send_message_length', -1)]
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address, options=options)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials, options=options)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')
  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id, data_buffer_time_limit_ms)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory

  def default_factory(id):
    # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
    return self._control_stub.GetProcessBundleDescriptor(
        beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
            process_bundle_descriptor_id=id))

  self._fns = KeyedDefaultDict(default_factory)
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  if status_address:
    try:
      self._status_handler = FnApiWorkerStatusHandler(
          status_address, self._bundle_processor_cache,
          enable_heap_dump)  # type: Optional[FnApiWorkerStatusHandler]
    except Exception:
      traceback_string = traceback.format_exc()
      _LOGGER.warning(
          'Error creating worker status request handler, '
          'skipping status report. Traceback: %s', traceback_string)
      # Ensure the attribute is set even when handler creation fails.
      self._status_handler = None
  else:
    self._status_handler = None

  # TODO(BEAM-8998) use common
  # thread_pool_executor.shared_unbounded_instance() to process bundle
  # progress once dataflow runner's excessive progress polling is removed.
  self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
  self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance()
  self._responses = queue.Queue(
  )  # type: queue.Queue[Union[beam_fn_api_pb2.InstructionResponse, Sentinel]]
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
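# A minimal usage sketch for the constructor above, assuming it belongs to
# SdkHarness (as the log message suggests) and that the harness exposes a
# run() loop; the address and cache size below are illustrative only.
harness = SdkHarness(
    control_address='localhost:50051',  # hypothetical control endpoint
    state_cache_size=100,  # enable caching for up to 100 state keys
    data_buffer_time_limit_ms=0)  # keep time-based data buffering disabled
harness.run()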
def get_cache(size):
  cache = StateCache(size)
  cache.initialize_metrics()
  return cache
def __init__(self, state_cache_size, credentials=None):
  self._state_handler_cache = {}
  self._lock = threading.Lock()
  self._throwing_state_handler = ThrowingStateHandler()
  self._credentials = credentials
  self._state_cache = StateCache(state_cache_size)
def test_lru(self):
  cache = StateCache(5)
  cache.put("key", "cache_token", "value")
  cache.put("key2", "cache_token2", "value2")
  cache.put("key3", "cache_token", "value0")
  cache.put("key3", "cache_token", "value3")
  cache.put("key4", "cache_token4", "value4")
  cache.put("key5", "cache_token", "value0")
  cache.put("key5", "cache_token", ["value5"])
  self.assertEqual(len(cache), 5)
  self.assertEqual(cache.get("key", "cache_token"), "value")
  self.assertEqual(cache.get("key2", "cache_token2"), "value2")
  self.assertEqual(cache.get("key3", "cache_token"), "value3")
  self.assertEqual(cache.get("key4", "cache_token4"), "value4")
  self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
  # insert another key to trigger cache eviction
  cache.put("key6", "cache_token2", "value7")
  self.assertEqual(len(cache), 5)
  # least recently used key should be gone ("key")
  self.assertEqual(cache.get("key", "cache_token"), None)
  # trigger a read on "key2"
  cache.get("key2", "cache_token")
  # insert another key to trigger cache eviction
  cache.put("key7", "cache_token", "value7")
  self.assertEqual(len(cache), 5)
  # least recently used key should be gone ("key3")
  self.assertEqual(cache.get("key3", "cache_token"), None)
  # trigger a put on "key2"
  cache.put("key2", "cache_token", "put")
  self.assertEqual(len(cache), 5)
  # insert another key to trigger cache eviction
  cache.put("key8", "cache_token", "value8")
  self.assertEqual(len(cache), 5)
  # least recently used key should be gone ("key4")
  self.assertEqual(cache.get("key4", "cache_token"), None)
  # make "key5" recently used by appending to it
  cache.extend("key5", "cache_token", ["another"])
  # "key6" was stored with "cache_token2", so reading it with "cache_token"
  # is a miss
  self.assertEqual(cache.get("key6", "cache_token"), None)
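# The tests above pin down the cache contract: entries are keyed by state key
# only, a value is returned only when the stored cache token matches, extend
# or clear with a stale token evicts the entry, and any access refreshes LRU
# recency. Below is a minimal sketch satisfying these tests, built on an
# OrderedDict; LruStateCache is a hypothetical name, not Beam's
# implementation.
import threading
from collections import OrderedDict


class LruStateCache(object):
  """Hypothetical token-validated LRU cache matching the tests above."""

  def __init__(self, max_entries):
    self._max_entries = max_entries
    self._cache = OrderedDict()  # state_key -> (cache_token, value)
    self._lock = threading.Lock()

  def __len__(self):
    return len(self._cache)

  def is_cache_enabled(self):
    return self._max_entries > 0

  def get(self, state_key, cache_token):
    assert cache_token is not None  # a missing token must raise
    with self._lock:
      if state_key not in self._cache:
        return None
      self._cache.move_to_end(state_key)  # any access refreshes recency
      token, value = self._cache[state_key]
      # A stale token yields a miss, but reads do not evict.
      return value if token == cache_token else None

  def put(self, state_key, cache_token, value):
    with self._lock:
      self._insert(state_key, cache_token, value)

  def extend(self, state_key, cache_token, elements):
    with self._lock:
      token, value = self._cache.get(state_key, (cache_token, []))
      if token == cache_token:
        self._insert(state_key, cache_token, list(value) + list(elements))
      else:
        # The token changed, so the cached value is stale: drop it.
        self._cache.pop(state_key, None)

  def clear(self, state_key, cache_token):
    with self._lock:
      token, _ = self._cache.get(state_key, (None, None))
      if token is None or token == cache_token:
        self._insert(state_key, cache_token, [])
      else:
        self._cache.pop(state_key, None)

  def evict_all(self):
    with self._lock:
      self._cache.clear()

  def _insert(self, state_key, cache_token, value):
    # Insert (or replace) and evict least recently used entries on overflow.
    self._cache[state_key] = (cache_token, value)
    self._cache.move_to_end(state_key)
    while len(self._cache) > self._max_entries:
      self._cache.popitem(last=False)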
def test_is_cache_enabled(self):
  cache = StateCache(1)
  self.assertEqual(cache.is_cache_enabled(), True)
  cache = StateCache(0)
  self.assertEqual(cache.is_cache_enabled(), False)
def test_clear(self):
  cache = StateCache(5)
  cache.clear("new-key", "cache_token")
  cache.put("key", "cache_token", ["value"])
  self.assertEqual(len(cache), 2)
  self.assertEqual(cache.get("new-key", "new_token"), None)
  self.assertEqual(cache.get("key", "cache_token"), ['value'])
  # test clear without existing key/token
  cache.clear("non-existing", "token")
  self.assertEqual(len(cache), 3)
  self.assertEqual(cache.get("non-existing", "token"), [])
  # test eviction in case the cache token changes
  cache.clear("new-key", "wrong_token")
  self.assertEqual(len(cache), 2)
  self.assertEqual(cache.get("new-key", "cache_token"), None)
  self.assertEqual(cache.get("new-key", "wrong_token"), None)
def test_extend(self):
  cache = StateCache(3)
  cache.put("key", "cache_token", ['val'])
  # test extend for existing key
  cache.extend("key", "cache_token", ['yet', 'another', 'val'])
  self.assertEqual(len(cache), 1)
  self.assertEqual(
      cache.get("key", "cache_token"), ['val', 'yet', 'another', 'val'])
  # test extend without existing key
  cache.extend("key2", "cache_token", ['another', 'val'])
  self.assertEqual(len(cache), 2)
  self.assertEqual(cache.get("key2", "cache_token"), ['another', 'val'])
  # test eviction in case the cache token changes
  cache.extend("key2", "new_token", ['new_value'])
  self.assertEqual(cache.get("key2", "new_token"), None)
  self.assertEqual(len(cache), 1)
def test_empty_cache_get(self):
  cache = StateCache(5)
  self.assertEqual(cache.get("key", 'cache_token'), None)
  with self.assertRaises(Exception):
    self.assertEqual(cache.get("key", None), None)