class FnApiWorkerStatusHandlerTest(unittest.TestCase):
    """Runs FnApiWorkerStatusHandler against an in-process gRPC status service.

    Each test waits until the servicer has recorded ``num_request`` responses
    from the handler and then inspects those responses.
    """

    def setUp(self):
        """Start a local status server and connect a status handler to it."""
        self.num_request = 3
        self.test_status_service = BeamFnStatusServicer(self.num_request)
        self.server = grpc.server(
            thread_pool_executor.shared_unbounded_instance())
        beam_fn_api_pb2_grpc.add_BeamFnWorkerStatusServicer_to_server(
            self.test_status_service, self.server)
        # The value returned by add_insecure_port is the port the handler
        # connects to below.
        self.test_port = self.server.add_insecure_port('[::]:0')
        self.server.start()
        self.url = 'localhost:%s' % self.test_port
        self.fn_status_handler = FnApiWorkerStatusHandler(self.url)

    def tearDown(self):
        """Stop the local status server."""
        self.server.stop(5)

    def _wait_for_responses(self):
        """Block until the servicer holds ``num_request`` responses.

        The ``finished`` condition is always released, even if the wait loop
        is interrupted by an exception, so a failure here cannot leave it
        held.

        Returns:
          The servicer's ``response_received`` collection.
        """
        service = self.test_status_service
        service.finished.acquire()
        try:
            while len(service.response_received) < self.num_request:
                service.finished.wait(1)
        finally:
            service.finished.release()
        return service.response_received

    @timeout(5)
    def test_send_status_response(self):
        """Every recorded response carries a ``status_info`` value."""
        try:
            for response in self._wait_for_responses():
                self.assertIsNotNone(response.status_info)
        finally:
            # Close the handler even when an assertion above fails.
            self.fn_status_handler.close()

    @timeout(5)
    @mock.patch(
        'apache_beam.runners.worker.worker_status'
        '.FnApiWorkerStatusHandler.generate_status_response')
    def test_generate_error(self, mock_method):
        """A failing status generator still yields responses with ``error``."""
        mock_method.side_effect = RuntimeError('error')
        try:
            for response in self._wait_for_responses():
                self.assertIsNotNone(response.error)
        finally:
            # Close the handler even when an assertion above fails.
            self.fn_status_handler.close()
class SdkHarness(object):
    """Consumes instruction requests from a control channel and answers them.

    ``run`` iterates over the requests returned by the control stub and
    dispatches each one to the method named ``_request_<request type>``.
    Responses are placed on an internal queue that is streamed back through
    the same ``Control`` call.
    """
    REQUEST_METHOD_PREFIX = '_request_'

    def __init__(self,
                 control_address,  # type: str
                 credentials=None,
                 worker_id=None,  # type: Optional[str]
                 # Caching is disabled by default
                 state_cache_size=0,
                 # time-based data buffering is disabled by default
                 data_buffer_time_limit_ms=0,
                 profiler_factory=None,  # type: Optional[Callable[..., Profile]]
                 status_address=None,  # type: Optional[str]
                 ):
        """Open the control channel and build the shared worker resources.

        Args:
          control_address: address the control channel is opened against.
          credentials: when None an insecure channel is created, otherwise a
            secure channel using these credentials.
          worker_id: passed to the WorkerIdInterceptor and the data channel
            factory.
          state_cache_size: size handed to StateCache.
          data_buffer_time_limit_ms: handed to the data channel factory.
          profiler_factory: forwarded to every SdkWorker that is created.
          status_address: when truthy, a FnApiWorkerStatusHandler is created
            for it; a failure to create one is logged and status reporting
            is skipped.

        Raises:
          Whatever ``grpc.channel_ready_future(...).result(timeout=60)``
          raises when the control channel does not become ready.
        """
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        if credentials is None:
            _LOGGER.info(
                'Creating insecure control channel for %s.', control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address)
        else:
            _LOGGER.info(
                'Creating secure control channel for %s.', control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials)
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        _LOGGER.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id, data_buffer_time_limit_ms)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory
        self._fns = {
        }  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)

        # Assign the attribute up front: run() reads it unconditionally, so it
        # must exist even when creating the handler below fails.
        self._status_handler = None
        if status_address:
            try:
                self._status_handler = FnApiWorkerStatusHandler(
                    status_address, self._bundle_processor_cache)
            except Exception:  # pylint: disable=broad-except
                _LOGGER.warning(
                    'Error creating worker status request handler, '
                    'skipping status report. Trace back: %s',
                    traceback.format_exc())

        # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process
        # bundle progress once dataflow runner's excessive progress polling is
        # removed.
        self._report_progress_executor = futures.ThreadPoolExecutor(
            max_workers=1)
        self._worker_thread_pool = UnboundedThreadPoolExecutor()
        self._responses = queue.Queue(
        )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')

    def run(self):
        """Process control requests until the stream ends, then clean up."""
        control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
            self._control_channel)
        # Unique marker object; queued to make get_responses() return.
        no_more_work = object()

        def get_responses():
            # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
            while True:
                response = self._responses.get()
                if response is no_more_work:
                    return
                yield response

        self._alive = True
        try:
            for work_request in control_stub.Control(get_responses()):
                _LOGGER.debug('Got work %s', work_request.instruction_id)
                request_type = work_request.WhichOneof('request')
                # Dispatch by name: a 'register' request is handled by
                # self._request_register(work_request).
                handler = getattr(
                    self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)
                handler(work_request)
        finally:
            self._alive = False

        _LOGGER.info('No more requests from control plane')
        _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
        # Wait until existing requests are processed.
        self._worker_thread_pool.shutdown()
        # get_responses may be blocked on responses.get(), but we need to
        # return control to its caller.
        self._responses.put(no_more_work)
        # Stop all the workers and clean all the associated resources
        self._data_channel_factory.close()
        self._state_handler_factory.close()
        self._bundle_processor_cache.shutdown()
        if self._status_handler:
            self._status_handler.close()
        _LOGGER.info('Done consuming work.')

    def _execute(self,
                 task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
                 request  # type: beam_fn_api_pb2.InstructionRequest
                 ):
        # type: (...) -> None
        """Run ``task`` and queue its response, or an error response.

        Any exception raised by ``task`` is printed to stderr, logged, and
        turned into an InstructionResponse whose ``error`` is the traceback.
        """
        with statesampler.instruction_id(request.instruction_id):
            try:
                response = task()
            except Exception:  # pylint: disable=broad-except
                traceback_string = traceback.format_exc()
                print(traceback_string, file=sys.stderr)
                _LOGGER.error(
                    'Error processing instruction %s. '
                    'Original traceback is\n%s\n',
                    request.instruction_id,
                    traceback_string)
                response = beam_fn_api_pb2.InstructionResponse(
                    instruction_id=request.instruction_id,
                    error=traceback_string)
            self._responses.put(response)

    def _request_register(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Handle a register request on the calling thread."""
        # registration request is handled synchronously
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)

    def _request_process_bundle(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a process-bundle request to the worker thread pool."""
        self._request_execute(request)

    def _request_process_bundle_split(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a split request to the progress executor."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_progress(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a progress request to the progress executor."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_action(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Submit a progress/split request to the progress executor.

        The request is executed only when the bundle it refers to is listed
        in the cache's ``active_bundle_processors``; otherwise an error
        response is queued.
        """

        def task():
            instruction_id = getattr(
                request, request.WhichOneof('request')).instruction_id
            # only process progress/split request when a bundle is in
            # processing.
            active = self._bundle_processor_cache.active_bundle_processors
            if instruction_id in active:
                self._execute(
                    lambda: self.create_worker().do_instruction(request),
                    request)
            else:
                self._execute(
                    lambda: beam_fn_api_pb2.InstructionResponse(
                        instruction_id=request.instruction_id,
                        error=('Unknown process bundle instruction {}').format(
                            instruction_id)),
                    request)

        self._report_progress_executor.submit(task)

    def _request_finalize_bundle(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a finalize-bundle request to the worker thread pool."""
        self._request_execute(request)

    def _request_execute(self, request):
        """Submit ``request`` to the worker thread pool for execution."""

        def task():
            self._execute(
                lambda: self.create_worker().do_instruction(request), request)

        self._worker_thread_pool.submit(task)
        _LOGGER.debug(
            "Currently using %s threads.",
            len(self._worker_thread_pool._workers))

    def create_worker(self):
        """Return a new SdkWorker backed by the shared bundle processor cache."""
        return SdkWorker(
            self._bundle_processor_cache,
            state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
            profiler_factory=self._profiler_factory)
class SdkHarness(object):
    """Consumes instruction requests from a control channel and answers them.

    ``run`` iterates over the requests returned by the control stub and
    dispatches each one to the method named ``_request_<request type>``.
    Responses are placed on an internal queue that is streamed back through
    the same ``Control`` call.
    """
    REQUEST_METHOD_PREFIX = '_request_'

    def __init__(
        self,
        control_address,  # type: str
        credentials=None,  # type: Optional[grpc.ChannelCredentials]
        worker_id=None,  # type: Optional[str]
        # Caching is disabled by default
        state_cache_size=0,  # type: int
        # time-based data buffering is disabled by default
        data_buffer_time_limit_ms=0,  # type: int
        profiler_factory=None,  # type: Optional[Callable[..., Profile]]
        status_address=None,  # type: Optional[str]
        # Heap dump through status api is disabled by default
        enable_heap_dump=False,  # type: bool
    ):
        # type: (...) -> None
        """Open the control channel and build the shared worker resources.

        Args:
          control_address: address the control channel is opened against.
          credentials: when None an insecure channel is created, otherwise a
            secure channel using these credentials.
          worker_id: passed to the WorkerIdInterceptor and the data channel
            factory.
          state_cache_size: size handed to StateCache.
          data_buffer_time_limit_ms: handed to the data channel factory.
          profiler_factory: forwarded to every SdkWorker that is created.
          status_address: when truthy, a FnApiWorkerStatusHandler is created
            for it; a failure to create one is logged and status reporting
            is skipped.
          enable_heap_dump: forwarded to FnApiWorkerStatusHandler.

        Raises:
          Whatever ``grpc.channel_ready_future(...).result(timeout=60)``
          raises when the control channel does not become ready.
        """
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        options = [('grpc.max_receive_message_length', -1),
                   ('grpc.max_send_message_length', -1)]
        if credentials is None:
            _LOGGER.info(
                'Creating insecure control channel for %s.', control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address, options=options)
        else:
            _LOGGER.info(
                'Creating secure control channel for %s.', control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials, options=options)
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        _LOGGER.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id, data_buffer_time_limit_ms)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory

        def default_factory(descriptor_id):
            # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
            # Uses self._control_stub, which is assigned in run().
            return self._control_stub.GetProcessBundleDescriptor(
                beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
                    process_bundle_descriptor_id=descriptor_id))

        self._fns = KeyedDefaultDict(default_factory)
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)

        # Assign the attribute up front: run() reads it unconditionally, so it
        # must exist even when creating the handler below fails.
        self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
        if status_address:
            try:
                self._status_handler = FnApiWorkerStatusHandler(
                    status_address,
                    self._bundle_processor_cache,
                    enable_heap_dump)
            except Exception:  # pylint: disable=broad-except
                _LOGGER.warning(
                    'Error creating worker status request handler, '
                    'skipping status report. Trace back: %s',
                    traceback.format_exc())

        # TODO(BEAM-8998) use common
        # thread_pool_executor.shared_unbounded_instance() to process bundle
        # progress once dataflow runner's excessive progress polling is
        # removed.
        self._report_progress_executor = futures.ThreadPoolExecutor(
            max_workers=1)
        self._worker_thread_pool = (
            thread_pool_executor.shared_unbounded_instance())
        self._responses = queue.Queue(
        )  # type: queue.Queue[Union[beam_fn_api_pb2.InstructionResponse, Sentinel]]
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')

    def run(self):
        # type: () -> None
        """Process control requests until the stream ends, then clean up."""
        self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
            self._control_channel)
        # Marker queued to make get_responses() return.
        no_more_work = Sentinel.sentinel

        def get_responses():
            # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
            while True:
                response = self._responses.get()
                if response is no_more_work:
                    return
                yield response

        self._alive = True
        try:
            for work_request in self._control_stub.Control(get_responses()):
                _LOGGER.debug('Got work %s', work_request.instruction_id)
                request_type = work_request.WhichOneof('request')
                # Dispatch by name: a 'register' request is handled by
                # self._request_register(work_request).
                handler = getattr(
                    self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)
                handler(work_request)
        finally:
            self._alive = False

        _LOGGER.info('No more requests from control plane')
        _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
        # Wait until existing requests are processed.
        self._worker_thread_pool.shutdown()
        # get_responses may be blocked on responses.get(), but we need to
        # return control to its caller.
        self._responses.put(no_more_work)
        # Stop all the workers and clean all the associated resources
        self._data_channel_factory.close()
        self._state_handler_factory.close()
        self._bundle_processor_cache.shutdown()
        if self._status_handler:
            self._status_handler.close()
        _LOGGER.info('Done consuming work.')

    def _execute(
        self,
        task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
        request  # type: beam_fn_api_pb2.InstructionRequest
    ):
        # type: (...) -> None
        """Run ``task`` and queue its response, or an error response.

        Any exception raised by ``task`` is printed to stderr, logged, and
        turned into an InstructionResponse whose ``error`` is the traceback.
        """
        with statesampler.instruction_id(request.instruction_id):
            try:
                response = task()
            except Exception:  # pylint: disable=broad-except
                traceback_string = traceback.format_exc()
                print(traceback_string, file=sys.stderr)
                _LOGGER.error(
                    'Error processing instruction %s. '
                    'Original traceback is\n%s\n',
                    request.instruction_id,
                    traceback_string)
                response = beam_fn_api_pb2.InstructionResponse(
                    instruction_id=request.instruction_id,
                    error=traceback_string)
            self._responses.put(response)

    def _request_register(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Handle a register request on the calling thread."""
        # registration request is handled synchronously
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)

    def _request_process_bundle(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Activate the instruction in the cache, then run it on the pool."""
        self._bundle_processor_cache.activate(request.instruction_id)
        self._request_execute(request)

    def _request_process_bundle_split(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a split request to the progress executor."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_progress(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a progress request to the progress executor."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_action(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Submit a progress/split request to the progress executor."""

        def task():
            # type: () -> None
            self._execute(
                lambda: self.create_worker().do_instruction(request), request)

        self._report_progress_executor.submit(task)

    def _request_finalize_bundle(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Hand a finalize-bundle request to the worker thread pool."""
        self._request_execute(request)

    def _request_harness_monitoring_infos(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Respond with the process-wide monitoring data keyed by short id."""
        process_wide_monitoring_infos = (
            MetricsEnvironment.process_wide_container()
            .to_runner_api_monitoring_infos(None).values())
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                harness_monitoring_infos=(
                    beam_fn_api_pb2.HarnessMonitoringInfosResponse(
                        monitoring_data={
                            SHORT_ID_CACHE.get_short_id(info): info.payload
                            for info in process_wide_monitoring_infos
                        }))),
            request)

    def _request_monitoring_infos(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Respond with the infos SHORT_ID_CACHE returns for the given ids."""
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                monitoring_infos=(
                    beam_fn_api_pb2.MonitoringInfosMetadataResponse(
                        monitoring_info=SHORT_ID_CACHE.get_infos(
                            request.monitoring_infos.monitoring_info_id)))),
            request)

    def _request_execute(self, request):
        # type: (beam_fn_api_pb2.InstructionRequest) -> None
        """Submit ``request`` to the worker thread pool for execution."""

        def task():
            # type: () -> None
            self._execute(
                lambda: self.create_worker().do_instruction(request), request)

        self._worker_thread_pool.submit(task)
        _LOGGER.debug(
            "Currently using %s threads.",
            len(self._worker_thread_pool._workers))

    def create_worker(self):
        # type: () -> SdkWorker
        """Return a new SdkWorker backed by the shared bundle processor cache."""
        return SdkWorker(
            self._bundle_processor_cache,
            state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
            profiler_factory=self._profiler_factory)