Exemplo n.º 1
0
class FnApiWorkerStatusHandlerTest(unittest.TestCase):
    """Exercises FnApiWorkerStatusHandler against an in-process gRPC server.

    ``setUp`` starts a server hosting a ``BeamFnStatusServicer`` built with
    ``num_request`` and points a ``FnApiWorkerStatusHandler`` at it. Each test
    waits until the servicer has recorded at least ``num_request`` responses
    and then inspects them.
    """

    def setUp(self):
        self.num_request = 3
        self.test_status_service = BeamFnStatusServicer(self.num_request)
        self.server = grpc.server(
            thread_pool_executor.shared_unbounded_instance())
        beam_fn_api_pb2_grpc.add_BeamFnWorkerStatusServicer_to_server(
            self.test_status_service, self.server)
        # The value returned by add_insecure_port is used as the port the
        # handler connects to.
        self.test_port = self.server.add_insecure_port('[::]:0')
        self.server.start()
        self.url = 'localhost:%s' % self.test_port
        self.fn_status_handler = FnApiWorkerStatusHandler(self.url)

    def tearDown(self):
        # Close the handler here rather than at the end of each test so it is
        # released even when an assertion fails or the test times out; the
        # server is stopped regardless of whether closing succeeds.
        try:
            self.fn_status_handler.close()
        finally:
            self.server.stop(5)

    def _wait_for_responses(self):
        """Block until the servicer holds at least ``num_request`` responses.

        The ``finished`` lock is always released, even if waiting raises.
        """
        finished = self.test_status_service.finished
        finished.acquire()
        try:
            while (len(self.test_status_service.response_received) <
                   self.num_request):
                finished.wait(1)
        finally:
            finished.release()

    @timeout(5)
    def test_send_status_response(self):
        """Every response received by the servicer carries a status_info."""
        self._wait_for_responses()
        # NOTE(review): if these are protobuf messages, scalar fields are
        # never None and this check cannot fail - confirm the intended check.
        for response in self.test_status_service.response_received:
            self.assertIsNotNone(response.status_info)

    @timeout(5)
    @mock.patch('apache_beam.runners.worker.worker_status'
                '.FnApiWorkerStatusHandler.generate_status_response')
    def test_generate_error(self, mock_method):
        """Responses are still sent when generating the status raises."""
        mock_method.side_effect = RuntimeError('error')
        self._wait_for_responses()
        for response in self.test_status_service.response_received:
            self.assertIsNotNone(response.error)
Exemplo n.º 2
0
class SdkHarness(object):
  """Drives SDK workers from instructions received on the control channel.

  ``run`` reads requests from the control stream and dispatches each one to
  the method named ``REQUEST_METHOD_PREFIX + <request type>``. Responses are
  placed on an internal queue that feeds the outgoing side of the stream.
  """
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
               ):
    """Connect the control channel and build the shared worker resources.

    Blocks for up to 60 seconds waiting for the control channel to become
    ready.

    Args:
      control_address: address the control channel is opened against.
      credentials: when None an insecure channel is created; otherwise a
        secure channel is created with these credentials. Also forwarded to
        the data channel and state handler factories.
      worker_id: id attached to the control channel through
        WorkerIdInterceptor and forwarded to the data channel factory.
      state_cache_size: size passed to StateCache.
      data_buffer_time_limit_ms: limit passed to the data channel factory.
      profiler_factory: forwarded to every SdkWorker made by create_worker.
      status_address: when truthy, a FnApiWorkerStatusHandler is created for
        this address. A failure to create it is logged and status reporting
        is skipped.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Assign the attribute up front: run() reads it unconditionally, so it
    # must exist even when creating the handler below fails.
    self._status_handler = None
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache)
      except Exception:
        traceback_string = traceback.format_exc()
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s',
            traceback_string)

    # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
    #  progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    """Serve control requests until the stream ends, then release resources.

    If dispatching raises, ``_alive`` is cleared and the exception
    propagates; the shutdown sequence after the loop is only reached when the
    request stream ends without an exception.
    """
    control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
    # Unique marker queued to make get_responses() stop iterating.
    no_more_work = object()

    def get_responses():
      # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with '_request_'. The called method
        # will be like self._request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._worker_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    if self._status_handler:
      self._status_handler.close()
    _LOGGER.info('Done consuming work.')

  def _execute(self,
               task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
               request  # type:  beam_fn_api_pb2.InstructionRequest
              ):
    # type: (...) -> None
    """Run ``task`` for ``request`` and queue the resulting response.

    An Exception raised by ``task`` is printed to stderr, logged, and turned
    into an InstructionResponse whose ``error`` is the formatted traceback, so
    a response is queued either way.
    """
    with statesampler.instruction_id(request.instruction_id):
      try:
        response = task()
      except Exception:  # pylint: disable=broad-except
        traceback_string = traceback.format_exc()
        print(traceback_string, file=sys.stderr)
        _LOGGER.error(
            'Error processing instruction %s. Original traceback is\n%s\n',
            request.instruction_id,
            traceback_string)
        response = beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=traceback_string)
      self._responses.put(response)

  def _request_register(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a register request on the calling thread."""
    # registration request is handled synchronously
    self._execute(lambda: self.create_worker().do_instruction(request), request)

  def _request_process_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Submit a process-bundle request to the worker thread pool."""
    self._request_execute(request)

  def _request_process_bundle_split(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Queue a bundle split request on the progress executor."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Queue a bundle progress request on the progress executor."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a split/progress request on the single progress thread.

    The request is only executed when the bundle it refers to is among the
    cache's active bundle processors; otherwise an error response is queued.
    """

    def task():
      # The id of the bundle being acted on lives on the nested request
      # message, not on the InstructionRequest itself.
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)
      else:
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                error=('Unknown process bundle instruction {}').format(
                    instruction_id)),
            request)

    self._report_progress_executor.submit(task)

  def _request_finalize_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Submit a finalize-bundle request to the worker thread pool."""
    self._request_execute(request)

  def _request_execute(self, request):
    """Execute ``request`` with a fresh worker on the worker thread pool."""
    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    # Lazy logger arguments: the message is only formatted when debug
    # logging is enabled.
    _LOGGER.debug(
        "Currently using %s threads.", len(self._worker_thread_pool._workers))

  def create_worker(self):
    """Return a new SdkWorker sharing this harness's bundle processor cache."""
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)
Exemplo n.º 3
0
class SdkHarness(object):
  """Drives SDK workers from instructions received on the control channel.

  ``run`` reads requests from the control stream and dispatches each one to
  the method named ``REQUEST_METHOD_PREFIX + <request type>``. Responses are
  placed on an internal queue that feeds the outgoing side of the stream.
  """
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(
      self,
      control_address,  # type: str
      credentials=None,  # type: Optional[grpc.ChannelCredentials]
      worker_id=None,  # type: Optional[str]
      # Caching is disabled by default
      state_cache_size=0,  # type: int
      # time-based data buffering is disabled by default
      data_buffer_time_limit_ms=0,  # type: int
      profiler_factory=None,  # type: Optional[Callable[..., Profile]]
      status_address=None,  # type: Optional[str]
      # Heap dump through status api is disabled by default
      enable_heap_dump=False,  # type: bool
  ):
    # type: (...) -> None
    """Connect the control channel and build the shared worker resources.

    Blocks for up to 60 seconds waiting for the control channel to become
    ready.

    Args:
      control_address: address the control channel is opened against.
      credentials: when None an insecure channel is created; otherwise a
        secure channel is created with these credentials. Also forwarded to
        the data channel and state handler factories.
      worker_id: id attached to the control channel through
        WorkerIdInterceptor and forwarded to the data channel factory.
      state_cache_size: size passed to StateCache.
      data_buffer_time_limit_ms: limit passed to the data channel factory.
      profiler_factory: forwarded to every SdkWorker made by create_worker.
      status_address: when truthy, a FnApiWorkerStatusHandler is created for
        this address. A failure to create it is logged and status reporting
        is skipped.
      enable_heap_dump: forwarded to FnApiWorkerStatusHandler.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    # -1 is passed for both message length limits on the control channel.
    options = [('grpc.max_receive_message_length', -1),
               ('grpc.max_send_message_length', -1)]
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address, options=options)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials, options=options)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory

    def default_factory(descriptor_id):
      # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
      # Fetches the descriptor through the control stub. Note that
      # self._control_stub is only assigned in run().
      return self._control_stub.GetProcessBundleDescriptor(
          beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
              process_bundle_descriptor_id=descriptor_id))

    self._fns = KeyedDefaultDict(default_factory)
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Assign the attribute up front: run() reads it unconditionally, so it
    # must exist even when creating the handler below fails.
    self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache, enable_heap_dump)
      except Exception:
        traceback_string = traceback.format_exc()
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s',
            traceback_string)

    # TODO(BEAM-8998) use common
    # thread_pool_executor.shared_unbounded_instance() to process bundle
    # progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance()
    self._responses = queue.Queue(
    )  # type: queue.Queue[Union[beam_fn_api_pb2.InstructionResponse, Sentinel]]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    # type: () -> None
    """Serve control requests until the stream ends, then release resources.

    If dispatching raises, ``_alive`` is cleared and the exception
    propagates; the shutdown sequence after the loop is only reached when the
    request stream ends without an exception.
    """
    self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
        self._control_channel)
    # Marker queued to make get_responses() stop iterating.
    no_more_work = Sentinel.sentinel

    def get_responses():
      # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in self._control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with '_request_'. The called method
        # will be like self._request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._worker_thread_pool.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    if self._status_handler:
      self._status_handler.close()
    _LOGGER.info('Done consuming work.')

  def _execute(
      self,
      task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
      request  # type:  beam_fn_api_pb2.InstructionRequest
  ):
    # type: (...) -> None
    """Run ``task`` for ``request`` and queue the resulting response.

    An Exception raised by ``task`` is printed to stderr, logged, and turned
    into an InstructionResponse whose ``error`` is the formatted traceback, so
    a response is queued either way.
    """
    with statesampler.instruction_id(request.instruction_id):
      try:
        response = task()
      except Exception:  # pylint: disable=broad-except
        traceback_string = traceback.format_exc()
        print(traceback_string, file=sys.stderr)
        _LOGGER.error(
            'Error processing instruction %s. Original traceback is\n%s\n',
            request.instruction_id,
            traceback_string)
        response = beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=traceback_string)
      self._responses.put(response)

  def _request_register(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a register request on the calling thread."""
    # registration request is handled synchronously
    self._execute(lambda: self.create_worker().do_instruction(request), request)

  def _request_process_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Mark the bundle active, then submit it to the worker thread pool."""
    # Activation happens on the dispatching thread, before the request is
    # handed to the pool.
    self._bundle_processor_cache.activate(request.instruction_id)
    self._request_execute(request)

  def _request_process_bundle_split(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Queue a bundle split request on the progress executor."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Queue a bundle progress request on the progress executor."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Execute a split/progress request on the single progress thread."""
    def task():
      # type: () -> None
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._report_progress_executor.submit(task)

  def _request_finalize_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Submit a finalize-bundle request to the worker thread pool."""
    self._request_execute(request)

  def _request_harness_monitoring_infos(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Respond with the process-wide monitoring data keyed by short id."""
    # The monitoring infos are collected before _execute runs, i.e. outside
    # its exception handling.
    process_wide_monitoring_infos = MetricsEnvironment.process_wide_container(
    ).to_runner_api_monitoring_infos(None).values()
    self._execute(
        lambda: beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id,
            harness_monitoring_infos=(
                beam_fn_api_pb2.HarnessMonitoringInfosResponse(
                    monitoring_data={
                        SHORT_ID_CACHE.get_short_id(info): info.payload
                        for info in process_wide_monitoring_infos
                    }))),
        request)

  def _request_monitoring_infos(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Respond with the monitoring infos for the requested short ids."""
    self._execute(
        lambda: beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id,
            monitoring_infos=beam_fn_api_pb2.MonitoringInfosMetadataResponse(
                monitoring_info=SHORT_ID_CACHE.get_infos(
                    request.monitoring_infos.monitoring_info_id))),
        request)

  def _request_execute(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Execute ``request`` with a fresh worker on the worker thread pool."""
    def task():
      # type: () -> None
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    # Lazy logger arguments: the message is only formatted when debug
    # logging is enabled.
    _LOGGER.debug(
        "Currently using %s threads.", len(self._worker_thread_pool._workers))

  def create_worker(self):
    # type: () -> SdkWorker
    """Return a new SdkWorker sharing this harness's bundle processor cache."""
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)