def __init__(self,
             control_address,  # type: str
             credentials=None,
             worker_id=None,  # type: Optional[str]
             # Caching is disabled by default
             state_cache_size=0,
             # time-based data buffering is disabled by default
             data_buffer_time_limit_ms=0,
             profiler_factory=None,  # type: Optional[Callable[..., Profile]]
             status_address=None,  # type: Optional[str]
            ):
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')

  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id, data_buffer_time_limit_ms)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)

  if status_address:
    try:
      self._status_handler = FnApiWorkerStatusHandler(
          status_address, self._bundle_processor_cache)
    except Exception:
      traceback_string = traceback.format_exc()
      _LOGGER.warning(
          'Error creating worker status request handler, '
          'skipping status report. Trace back: %s' % traceback_string)
  else:
    self._status_handler = None

  # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
  # progress once dataflow runner's excessive progress polling is removed.
  self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
  self._worker_thread_pool = UnboundedThreadPoolExecutor()
  self._responses = queue.Queue(
  )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
def run(self):
  logging_server = grpc.server(UnboundedThreadPoolExecutor())
  logging_port = logging_server.add_insecure_port('[::]:0')
  logging_server.start()
  logging_servicer = BeamFnLoggingServicer()
  beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
      logging_servicer, logging_server)
  logging_descriptor = text_format.MessageToString(
      endpoints_pb2.ApiServiceDescriptor(url='localhost:%s' % logging_port))

  control_descriptor = text_format.MessageToString(
      endpoints_pb2.ApiServiceDescriptor(url=self._control_address))

  env_dict = dict(
      os.environ,
      CONTROL_API_SERVICE_DESCRIPTOR=control_descriptor,
      LOGGING_API_SERVICE_DESCRIPTOR=logging_descriptor)
  # only add worker_id when it is set.
  if self._worker_id:
    env_dict['WORKER_ID'] = self._worker_id

  with fn_api_runner.SUBPROCESS_LOCK:
    p = subprocess.Popen(self._worker_command_line, shell=True, env=env_dict)
  try:
    p.wait()
    if p.returncode:
      raise RuntimeError(
          'Worker subprocess exited with return code %s' % p.returncode)
  finally:
    if p.poll() is None:
      p.kill()
    logging_server.stop(0)
def start(cls,
          use_process=False,
          port=0,
          state_cache_size=0,
          data_buffer_time_limit_ms=-1,
          container_executable=None  # type: Optional[str]
         ):
  # type: (...) -> Tuple[str, grpc.Server]
  worker_server = grpc.server(UnboundedThreadPoolExecutor())
  worker_address = 'localhost:%s' % worker_server.add_insecure_port(
      '[::]:%s' % port)
  worker_pool = cls(
      use_process=use_process,
      container_executable=container_executable,
      state_cache_size=state_cache_size,
      data_buffer_time_limit_ms=data_buffer_time_limit_ms)
  beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
      worker_pool, worker_server)
  worker_server.start()

  # Register to kill the subprocesses on exit.
  def kill_worker_processes():
    for worker_process in worker_pool._worker_processes.values():
      worker_process.kill()

  atexit.register(kill_worker_processes)
  return worker_address, worker_server
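# A minimal usage sketch of the entry point above. The enclosing class name
# (`BeamFnExternalWorkerPoolServicer`) and the caller's wiring are assumptions
# here, not shown in the snippet: start the pool, hand its address to whatever
# launches EXTERNAL-environment workers, and stop the server when done.
address, server = BeamFnExternalWorkerPoolServicer.start(
    use_process=True, state_cache_size=100)
try:
  # Pass `address` to the runner as the external worker pool endpoint, e.g.
  # via --environment_type=EXTERNAL --environment_config=<address>.
  run_pipeline_against(address)  # hypothetical helper
finally:
  server.stop(0)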
def _check_fn_registration_multi_request(self, *args):
  """Check the function registration calls to the sdk_harness.

  Args:
    args: tuples of (request_count, process_bundles_per_request) giving the
      number of registration requests to send and the number of process
      bundle descriptors carried by each request.
  """
  for (request_count, process_bundles_per_request) in args:
    requests = []
    process_bundle_descriptors = []

    for i in range(request_count):
      pbd = self._get_process_bundles(i, process_bundles_per_request)
      process_bundle_descriptors.extend(pbd)
      requests.append(
          beam_fn_api_pb2.InstructionRequest(
              instruction_id=str(i),
              register=beam_fn_api_pb2.RegisterRequest(
                  process_bundle_descriptor=process_bundle_descriptors)))

    test_controller = BeamFnControlServicer(requests)

    server = grpc.server(UnboundedThreadPoolExecutor())
    beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
        test_controller, server)
    test_port = server.add_insecure_port("[::]:0")
    server.start()

    harness = sdk_worker.SdkHarness(
        "localhost:%s" % test_port, state_cache_size=100)
    harness.run()

    self.assertEqual(
        harness._bundle_processor_cache.fns,
        {item.id: item for item in process_bundle_descriptors})
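# The helper `_get_process_bundles` referenced above is not shown in this
# excerpt. A plausible sketch, consistent with how it is used (descriptor ids
# made unique across requests by prefixing the request index); the exact id
# scheme is an assumption:
def _get_process_bundles(self, prefix, size):
  return [
      beam_fn_api_pb2.ProcessBundleDescriptor(id='%s-%s' % (prefix, ix))
      for ix in range(size)
  ]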
def _stage_files(self, files):
  """Utility method to stage files.

  Args:
    files: a list of tuples of the form [(local_name, remote_name), ...]
      giving each artifact's name in the local temp folder and its desired
      name in the staging location.
  """
  server = grpc.server(UnboundedThreadPoolExecutor())
  staging_service = TestLocalFileSystemArtifactStagingServiceServicer(
      self._remote_dir)
  beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
      staging_service, server)
  test_port = server.add_insecure_port('[::]:0')
  server.start()
  stager = portable_stager.PortableStager(
      artifact_service_channel=grpc.insecure_channel(
          'localhost:%s' % test_port),
      staging_session_token='token')
  for from_file, to_file in files:
    stager.stage_artifact(
        local_path_to_artifact=os.path.join(self._temp_dir, from_file),
        artifact_name=to_file)
  stager.commit_manifest()
  return staging_service.manifest.artifact, staging_service.retrieval_tokens
def _grpc_data_channel_test(self, time_based_flush=False):
  if time_based_flush:
    data_servicer = data_plane.BeamFnDataServicer(
        data_buffer_time_limit_ms=100)
  else:
    data_servicer = data_plane.BeamFnDataServicer()
  worker_id = 'worker_0'
  data_channel_service = data_servicer.get_conn_by_worker_id(worker_id)

  server = grpc.server(UnboundedThreadPoolExecutor())
  beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(data_servicer, server)
  test_port = server.add_insecure_port('[::]:0')
  server.start()

  grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
  # Add workerId to the grpc channel
  grpc_channel = grpc.intercept_channel(
      grpc_channel, WorkerIdInterceptor(worker_id))
  data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)

  if time_based_flush:
    data_channel_client = data_plane.GrpcClientDataChannel(
        data_channel_stub, data_buffer_time_limit_ms=100)
  else:
    data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)

  try:
    self._data_channel_test(
        data_channel_service, data_channel_client, time_based_flush)
  finally:
    data_channel_client.close()
    data_channel_service.close()
    data_channel_client.wait()
    data_channel_service.wait()
def __init__(self,
             control_address,
             credentials=None,
             worker_id=None,
             # Caching is disabled by default
             state_cache_size=0,
             profiler_factory=None):
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')

  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
  # progress once dataflow runner's excessive progress polling is removed.
  self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
  self._worker_thread_pool = UnboundedThreadPoolExecutor()
  self._responses = queue.Queue()
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
def test_concurrent_requests(self):
  num_sessions = 7
  artifacts = collections.defaultdict(list)

  def name(index):
    # Overlapping names across sessions.
    return 'name%d' % (index // num_sessions)

  def session(index):
    return 'session%d' % (index % num_sessions)

  def delayed_data(data, index, max_msecs=1):
    time.sleep(max_msecs / 1000.0 * random.random())
    return ('%s_%d' % (data, index)).encode('ascii')

  def put(index):
    artifacts[session(index)].append(
        beam_artifact_api_pb2.ArtifactMetadata(name=name(index)))
    self._service.PutArtifact([
        self.put_metadata(session(index), name(index)),
        self.put_data(delayed_data('a', index)),
        self.put_data(delayed_data('b' * 20, index, 2))
    ])
    return session(index)

  def commit(session):
    return session, self._service.CommitManifest(
        beam_artifact_api_pb2.CommitManifestRequest(
            staging_session_token=session,
            manifest=beam_artifact_api_pb2.Manifest(
                artifact=artifacts[session]))).retrieval_token

  def check(index):
    self.assertEqual(
        delayed_data('a', index) + delayed_data('b' * 20, index, 2),
        self.retrieve_artifact(
            self._service, tokens[session(index)], name(index)))

  # pylint: disable=range-builtin-not-iterating
  pool = UnboundedThreadPoolExecutor()
  sessions = set(pool.map(put, range(100)))
  tokens = dict(pool.map(commit, sessions))
  # List forces materialization.
  _ = list(pool.map(check, range(100)))
def process_bundle(self,
                   inputs,  # type: Mapping[str, execution.PartitionableBuffer]
                   expected_outputs,  # type: DataOutput
                   fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
                   expected_output_timers,  # type: Dict[Tuple[str, str], str]
                   dry_run=False,
                  ):
  # type: (...) -> BundleProcessResult
  part_inputs = [{} for _ in range(self._num_workers)
                ]  # type: List[Dict[str, List[bytes]]]
  # Timers are only executed on the first worker
  # TODO(BEAM-9741): Split timers to multiple workers
  timer_inputs = [
      fired_timers if i == 0 else {} for i in range(self._num_workers)
  ]
  for name, input in inputs.items():
    for ix, part in enumerate(input.partition(self._num_workers)):
      part_inputs[ix][name] = part

  merged_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
  split_result_list = [
  ]  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]

  def execute(part_map_input_timers):
    # type: (...) -> BundleProcessResult
    part_map, input_timers = part_map_input_timers
    bundle_manager = BundleManager(
        self.bundle_context_manager,
        self._progress_frequency,
        cache_token_generator=self._cache_token_generator)
    return bundle_manager.process_bundle(
        part_map,
        expected_outputs,
        input_timers,
        expected_output_timers,
        dry_run)

  with UnboundedThreadPoolExecutor() as executor:
    for result, split_result in executor.map(
        execute,
        zip(part_inputs,  # pylint: disable=zip-builtin-not-iterating
            timer_inputs)):
      split_result_list += split_result
      if merged_result is None:
        merged_result = result
      else:
        merged_result = beam_fn_api_pb2.InstructionResponse(
            process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                monitoring_infos=monitoring_infos.consolidate(
                    itertools.chain(
                        result.process_bundle.monitoring_infos,
                        merged_result.process_bundle.monitoring_infos))),
            error=result.error or merged_result.error)
  assert merged_result is not None
  return merged_result, split_result_list
def setUp(self):
  self.num_request = 3
  self.test_status_service = BeamFnStatusServicer(self.num_request)
  self.server = grpc.server(UnboundedThreadPoolExecutor())
  beam_fn_api_pb2_grpc.add_BeamFnWorkerStatusServicer_to_server(
      self.test_status_service, self.server)
  self.test_port = self.server.add_insecure_port('[::]:0')
  self.server.start()
  self.url = 'localhost:%s' % self.test_port
  self.fn_status_handler = FnApiWorkerStatusHandler(self.url)
def test_shutdown_with_slow_workers(self):
  futures = []
  with UnboundedThreadPoolExecutor() as executor:
    for _ in range(0, 5):
      futures.append(executor.submit(self.append_and_sleep, 1))

  for future in futures:
    future.result(timeout=10)

  with self._lock:
    self.assertEqual(5, len(self._worker_idents))
def start_grpc_server(self, port=0):
  self._server = grpc.server(UnboundedThreadPoolExecutor())
  port = self._server.add_insecure_port('localhost:%d' % port)
  beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
  beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
      self._artifact_service, self._server)
  self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
      url='localhost:%d' % port)
  self._server.start()
  _LOGGER.info('Grpc server started on port %s', port)
  return port
def __init__(self,
             control_address,  # type: str
             credentials=None,
             worker_id=None,  # type: Optional[str]
             # Caching is disabled by default
             state_cache_size=0,
             profiler_factory=None  # type: Optional[Callable[..., Profile]]
            ):
  self._alive = True
  self._worker_index = 0
  self._worker_id = worker_id
  self._state_cache = StateCache(state_cache_size)
  if credentials is None:
    _LOGGER.info('Creating insecure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.insecure_channel(
        control_address)
  else:
    _LOGGER.info('Creating secure control channel for %s.', control_address)
    self._control_channel = GRPCChannelFactory.secure_channel(
        control_address, credentials)
  grpc.channel_ready_future(self._control_channel).result(timeout=60)
  _LOGGER.info('Control channel established.')

  self._control_channel = grpc.intercept_channel(
      self._control_channel, WorkerIdInterceptor(self._worker_id))
  self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
      credentials, self._worker_id)
  self._state_handler_factory = GrpcStateHandlerFactory(
      self._state_cache, credentials)
  self._profiler_factory = profiler_factory
  self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
  # BundleProcessor cache across all workers.
  self._bundle_processor_cache = BundleProcessorCache(
      state_handler_factory=self._state_handler_factory,
      data_channel_factory=self._data_channel_factory,
      fns=self._fns)
  self._worker_thread_pool = UnboundedThreadPoolExecutor()
  self._responses = queue.Queue(
  )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
  _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
def start_grpc_server(self, port=0):
  self._server = grpc.server(UnboundedThreadPoolExecutor())
  port = self._server.add_insecure_port(
      '%s:%d' % (self.get_bind_address(), port))
  beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
  beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
      self._artifact_service, self._server)
  hostname = self.get_service_address()
  self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
      url='%s:%d' % (hostname, port))
  self._server.start()
  _LOGGER.info('Grpc server started at %s on port %d' % (hostname, port))
  return port
def test_exception_propagation(self):
  with UnboundedThreadPoolExecutor() as executor:
    future = executor.submit(self.raise_error, 'footest')
    try:
      future.result()
    except Exception:
      message = traceback.format_exc()
    else:
      raise AssertionError('expected exception not raised')
  self.assertIn('footest', message)
  self.assertIn('raise_error', message)
def setUp(self):
  self.test_logging_service = BeamFnLoggingServicer()
  self.server = grpc.server(UnboundedThreadPoolExecutor())
  beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
      self.test_logging_service, self.server)
  self.test_port = self.server.add_insecure_port('[::]:0')
  self.server.start()

  self.logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
  self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
  self.fn_log_handler = log_handler.FnApiLogRecordHandler(
      self.logging_service_descriptor)
  logging.getLogger().setLevel(logging.INFO)
  logging.getLogger().addHandler(self.fn_log_handler)
def test_worker_reuse(self):
  futures = []
  with UnboundedThreadPoolExecutor() as executor:
    for _ in range(0, 5):
      futures.append(executor.submit(self.append_and_sleep, 0.01))
    time.sleep(3)
    for _ in range(0, 5):
      futures.append(executor.submit(self.append_and_sleep, 0.01))

  for future in futures:
    future.result(timeout=10)

  with self._lock:
    self.assertEqual(10, len(self._worker_idents))
    self.assertTrue(len(set(self._worker_idents)) < 10)
def process_bundle(self,
                   inputs,  # type: Mapping[str, PartitionableBuffer]
                   expected_outputs,  # type: DataOutput
                   fired_timers,  # type: Mapping[Tuple[str, str], PartitionableBuffer]
                   expected_output_timers  # type: Dict[Tuple[str, str], str]
                  ):
  # type: (...) -> BundleProcessResult
  part_inputs = [{} for _ in range(self._num_workers)
                ]  # type: List[Dict[str, List[bytes]]]
  for name, input in inputs.items():
    for ix, part in enumerate(input.partition(self._num_workers)):
      part_inputs[ix][name] = part

  merged_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
  split_result_list = [
  ]  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]

  def execute(part_map):
    # type: (...) -> BundleProcessResult
    bundle_manager = BundleManager(
        self._worker_handler_list,
        self._get_buffer,
        self._get_input_coder_impl,
        self._bundle_descriptor,
        self._progress_frequency,
        self._registered,
        cache_token_generator=self._cache_token_generator)
    return bundle_manager.process_bundle(
        part_map, expected_outputs, fired_timers, expected_output_timers)

  with UnboundedThreadPoolExecutor() as executor:
    for result, split_result in executor.map(execute, part_inputs):
      split_result_list += split_result
      if merged_result is None:
        merged_result = result
      else:
        merged_result = beam_fn_api_pb2.InstructionResponse(
            process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                monitoring_infos=monitoring_infos.consolidate(
                    itertools.chain(
                        result.process_bundle.monitoring_infos,
                        merged_result.process_bundle.monitoring_infos))),
            error=result.error or merged_result.error)
  assert merged_result is not None
  return merged_result, split_result_list
def test_with_grpc(self):
  server = grpc.server(UnboundedThreadPoolExecutor())
  try:
    beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
        self._service, server)
    beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
        self._service, server)
    port = server.add_insecure_port('[::]:0')
    server.start()
    channel = grpc.insecure_channel('localhost:%d' % port)
    self._run_staging(
        beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(channel),
        beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub(channel))
    channel.close()
  finally:
    server.stop(1)
def main(unused_argv):
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-p', '--port', type=int, help='port on which to serve the job api')
  options = parser.parse_args()
  global server
  server = grpc.server(UnboundedThreadPoolExecutor())
  beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
      expansion_service.ExpansionServiceServicer(PipelineOptions()), server)
  server.add_insecure_port('localhost:{}'.format(options.port))
  server.start()
  _LOGGER.info('Listening for expansion requests at %d', options.port)

  signal.signal(signal.SIGTERM, cleanup)
  signal.signal(signal.SIGINT, cleanup)
  # blocking main thread forever.
  signal.pause()
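# The `cleanup` handler registered for SIGTERM/SIGINT above is not part of
# this excerpt. A minimal sketch, assuming it only needs to stop the
# module-level `server` so signal.pause() unblocks and the process exits:
def cleanup(unused_signum, unused_frame):
  _LOGGER.info('Shutting down expansion service.')
  server.stop(None)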
def start(cls,
          use_process=False,
          port=0,
          state_cache_size=0,
          container_executable=None):
  worker_server = grpc.server(UnboundedThreadPoolExecutor())
  worker_address = 'localhost:%s' % worker_server.add_insecure_port(
      '[::]:%s' % port)
  worker_pool = cls(
      use_process=use_process,
      container_executable=container_executable,
      state_cache_size=state_cache_size)
  beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
      worker_pool, worker_server)
  worker_server.start()

  # Register to kill the subprocesses on exit.
  def kill_worker_processes():
    for worker_process in worker_pool._worker_processes.values():
      worker_process.kill()

  atexit.register(kill_worker_processes)
  return worker_address, worker_server
def main(unused_argv):
  parser = argparse.ArgumentParser()
  parser.add_argument(
      '-p', '--port', type=int, help='port on which to serve the job api')
  options = parser.parse_args()
  global server
  server = grpc.server(UnboundedThreadPoolExecutor())

  # DOCKER SDK Harness
  beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
      expansion_service.ExpansionServiceServicer(
          PipelineOptions(
              ["--experiments", "beam_fn_api", "--sdk_location", "container"])),
      server)

  # PROCESS SDK Harness
  # beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
  #     expansion_service.ExpansionServiceServicer(
  #         PipelineOptions.from_dictionary({
  #             'environment_type': 'PROCESS',
  #             'environment_config':
  #                 '{"command": "sdks/python/container/build/target/launcher/darwin_amd64/boot"}',
  #             'experiments': 'beam_fn_api',
  #             'sdk_location': 'container',
  #         })),
  #     server)

  server.add_insecure_port('localhost:{}'.format(options.port))
  server.start()
  _LOGGER.info('Listening for expansion requests at %d', options.port)

  signal.signal(signal.SIGTERM, cleanup)
  signal.signal(signal.SIGINT, cleanup)
  # blocking main thread forever.
  signal.pause()
def test_shutdown_with_no_workers(self):
  with UnboundedThreadPoolExecutor():
    pass
def test_map(self):
  with UnboundedThreadPoolExecutor() as executor:
    executor.map(self.append_and_sleep, itertools.repeat(0.01, 5))

  with self._lock:
    self.assertEqual(5, len(self._worker_idents))
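# The executor tests above (test_shutdown_with_slow_workers,
# test_exception_propagation, test_worker_reuse, test_shutdown_with_no_workers,
# test_map) rely on fixtures that are not included in these excerpts. A minimal
# sketch of the assumed setup; the class name and bodies are reconstructions,
# not the verbatim source:
import threading
import time
import unittest


class UnboundedThreadPoolExecutorTest(unittest.TestCase):
  def setUp(self):
    self._lock = threading.Lock()
    self._worker_idents = []

  def append_and_sleep(self, sleep_time):
    # Record which worker thread ran the task, then keep that thread busy.
    with self._lock:
      self._worker_idents.append(threading.current_thread().ident)
    time.sleep(sleep_time)

  def raise_error(self, message):
    raise ValueError(message)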
class SdkHarness(object):
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(self,
               control_address,
               credentials=None,
               worker_id=None,
               # Caching is disabled by default
               state_cache_size=0,
               profiler_factory=None):
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    if credentials is None:
      _LOGGER.info(
          'Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue()
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
        self._control_channel)
    no_more_work = object()

    def get_responses():
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with 'request_'. The called method
        # will be like self.request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._worker_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    _LOGGER.info('Done consuming work.')

  def _execute(self, task, request):
    try:
      response = task()
    except Exception:  # pylint: disable=broad-except
      traceback_string = traceback.format_exc()
      print(traceback_string, file=sys.stderr)
      _LOGGER.error(
          'Error processing instruction %s. Original traceback is\n%s\n',
          request.instruction_id,
          traceback_string)
      response = beam_fn_api_pb2.InstructionResponse(
          instruction_id=request.instruction_id, error=traceback_string)
    self._responses.put(response)

  def _request_register(self, request):
    self._request_execute(request)

  def _request_process_bundle(self, request):

    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    _LOGGER.debug(
        "Currently using %s threads." % len(self._worker_thread_pool._workers))

  def _request_process_bundle_split(self, request):
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):

    def task():
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)
      else:
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                error=('Unknown process bundle instruction {}').format(
                    instruction_id)),
            request)

    self._worker_thread_pool.submit(task)

  def _request_finalize_bundle(self, request):
    self._request_execute(request)

  def _request_execute(self, request):

    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)

  def create_worker(self):
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)
class SdkHarness(object):
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
              ):
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    if credentials is None:
      _LOGGER.info(
          'Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    self._fns = KeyedDefaultDict(
        lambda id: self._control_stub.GetProcessBundleDescriptor(
            beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
                process_bundle_descriptor_id=id))
    )  # type: Mapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache
        )  # type: Optional[FnApiWorkerStatusHandler]
      except Exception:
        traceback_string = traceback.format_exc()
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s' % traceback_string)
    else:
      self._status_handler = None

    # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
    # progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
        self._control_channel)
    no_more_work = object()

    def get_responses():
      # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in self._control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with 'request_'. The called method
        # will be like self.request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._worker_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    if self._status_handler:
      self._status_handler.close()
    _LOGGER.info('Done consuming work.')

  def _execute(self,
               task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
               request  # type: beam_fn_api_pb2.InstructionRequest
              ):
    # type: (...) -> None
    with statesampler.instruction_id(request.instruction_id):
      try:
        response = task()
      except Exception:  # pylint: disable=broad-except
        traceback_string = traceback.format_exc()
        print(traceback_string, file=sys.stderr)
        _LOGGER.error(
            'Error processing instruction %s. Original traceback is\n%s\n',
            request.instruction_id,
            traceback_string)
        response = beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=traceback_string)
      self._responses.put(response)

  def _request_register(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    # registration request is handled synchronously
    self._execute(
        lambda: self.create_worker().do_instruction(request), request)

  def _request_process_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    self._request_execute(request)

  def _request_process_bundle_split(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None

    def task():
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)
      else:
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                error=('Unknown process bundle instruction {}').format(
                    instruction_id)),
            request)

    self._report_progress_executor.submit(task)

  def _request_finalize_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    self._request_execute(request)

  def _request_execute(self, request):

    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    _LOGGER.debug(
        "Currently using %s threads." % len(self._worker_thread_pool._workers))

  def create_worker(self):
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)
def __init__(self,
             state,  # type: StateServicer
             provision_info,  # type: Optional[ExtendedProvisionInfo]
             worker_manager,  # type: WorkerHandlerManager
            ):
  # type: (...) -> None
  self.state = state
  self.provision_info = provision_info
  self.control_server = grpc.server(UnboundedThreadPoolExecutor())
  self.control_port = self.control_server.add_insecure_port('[::]:0')
  self.control_address = 'localhost:%s' % self.control_port

  # Options to have no limits (-1) on the size of the messages
  # received or sent over the data plane. The actual buffer size
  # is controlled in a layer above.
  no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                          ("grpc.max_send_message_length", -1)]
  self.data_server = grpc.server(
      UnboundedThreadPoolExecutor(), options=no_max_message_sizes)
  self.data_port = self.data_server.add_insecure_port('[::]:0')

  self.state_server = grpc.server(
      UnboundedThreadPoolExecutor(), options=no_max_message_sizes)
  self.state_port = self.state_server.add_insecure_port('[::]:0')

  self.control_handler = BeamFnControlServicer(worker_manager)
  beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
      self.control_handler, self.control_server)

  # If we have provision info, serve these off the control port as well.
  if self.provision_info:
    if self.provision_info.provision_info:
      beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server(
          BasicProvisionService(
              self.provision_info.provision_info, worker_manager),
          self.control_server)

    if self.provision_info.artifact_staging_dir:
      service = artifact_service.BeamFilesystemArtifactService(
          self.provision_info.artifact_staging_dir
      )  # type: beam_artifact_api_pb2_grpc.LegacyArtifactRetrievalServiceServicer
    else:
      service = EmptyArtifactRetrievalService()
    beam_artifact_api_pb2_grpc.add_LegacyArtifactRetrievalServiceServicer_to_server(
        service, self.control_server)

  self.data_plane_handler = data_plane.BeamFnDataServicer(
      DATA_BUFFER_TIME_LIMIT_MS)
  beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
      self.data_plane_handler, self.data_server)

  beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
      GrpcStateServicer(state), self.state_server)

  self.logging_server = grpc.server(
      UnboundedThreadPoolExecutor(), options=no_max_message_sizes)
  self.logging_port = self.logging_server.add_insecure_port('[::]:0')
  beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
      BasicLoggingService(), self.logging_server)

  _LOGGER.info('starting control server on port %s', self.control_port)
  _LOGGER.info('starting data server on port %s', self.data_port)
  _LOGGER.info('starting state server on port %s', self.state_port)
  _LOGGER.info('starting logging server on port %s', self.logging_port)
  self.logging_server.start()
  self.state_server.start()
  self.data_server.start()
  self.control_server.start()