Exemple #1
0
class SdkHarness(object):
    """Worker-side driver of the BeamFnControl request/response stream.

    The harness opens a gRPC control channel, reads instruction requests from
    it, dispatches each one to the ``_request_<type>`` method named after the
    populated ``request`` oneof field, and streams the responses produced by
    those handlers back over the same call.
    """
    # Prefix prepended to the request oneof name to find the handler method.
    REQUEST_METHOD_PREFIX = '_request_'

    def __init__(
            self,
            control_address,
            credentials=None,
            worker_id=None,
            # Caching is disabled by default
            state_cache_size=0,
            profiler_factory=None):
        """Connect to the control service and build the shared worker state.

        Blocks for up to 60 seconds waiting for the control channel to become
        ready; the error raised by the ready-future on timeout propagates to
        the caller.

        Args:
          control_address: Address the control channel connects to.
          credentials: Channel credentials. When None an insecure channel is
            created, otherwise a secure one.
          worker_id: Identifier handed to WorkerIdInterceptor and to the data
            channel factory.
          state_cache_size: Size passed to StateCache (0 by default).
          profiler_factory: Passed through to every SdkWorker built by
            create_worker().
        """
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        if credentials is None:
            _LOGGER.info('Creating insecure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address)
        else:
            _LOGGER.info('Creating secure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials)
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        _LOGGER.info('Control channel established.')

        # Wrap the ready channel so calls go through WorkerIdInterceptor.
        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory
        self._fns = {}
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)
        self._worker_thread_pool = UnboundedThreadPoolExecutor()
        # Responses produced by handler tasks, drained by run().
        self._responses = queue.Queue()
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')

    def run(self):
        """Serve control requests until the request stream ends, then shut down.

        Each incoming request is dispatched to the handler method whose name
        is REQUEST_METHOD_PREFIX plus the populated ``request`` oneof field.
        Once the stream is exhausted the worker thread pool is drained, the
        response generator is released, and the data channel factory, state
        handler factory and bundle processor cache are closed.

        If the request loop raises, ``_alive`` is cleared and the exception
        propagates; the shutdown steps after the loop are not executed.
        """
        control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
            self._control_channel)
        # Unique sentinel that tells get_responses() to stop iterating.
        no_more_work = object()

        def get_responses():
            while True:
                response = self._responses.get()
                if response is no_more_work:
                    return
                yield response

        self._alive = True

        try:
            for work_request in control_stub.Control(get_responses()):
                _LOGGER.debug('Got work %s', work_request.instruction_id)
                request_type = work_request.WhichOneof('request')
                # Name spacing the request method with '_request_'. The called
                # method will be like self._request_register(request)
                getattr(self, SdkHarness.REQUEST_METHOD_PREFIX +
                        request_type)(work_request)
        finally:
            self._alive = False

        _LOGGER.info('No more requests from control plane')
        _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
        # Wait until existing requests are processed.
        self._worker_thread_pool.shutdown()
        # get_responses may be blocked on responses.get(), but we need to return
        # control to its caller.
        self._responses.put(no_more_work)
        # Stop all the workers and clean all the associated resources
        self._data_channel_factory.close()
        self._state_handler_factory.close()
        self._bundle_processor_cache.shutdown()
        _LOGGER.info('Done consuming work.')

    def _execute(self, task, request):
        """Run ``task`` and enqueue its result as the response to ``request``.

        If ``task`` raises, the traceback is printed to stderr and logged, and
        an InstructionResponse carrying the traceback in its ``error`` field
        is enqueued instead, so a response is queued either way.

        Args:
          task: Zero-argument callable returning the response to enqueue.
          request: The instruction request being answered; only its
            ``instruction_id`` is read, and only on the error path.
        """
        try:
            response = task()
        except Exception:  # pylint: disable=broad-except
            traceback_string = traceback.format_exc()
            print(traceback_string, file=sys.stderr)
            _LOGGER.error(
                'Error processing instruction %s. Original traceback is\n%s\n',
                request.instruction_id, traceback_string)
            response = beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id, error=traceback_string)
        self._responses.put(response)

    def _request_register(self, request):
        """Handle a ``register`` request on the worker thread pool."""
        self._request_execute(request)

    def _request_process_bundle(self, request):
        """Handle a ``process_bundle`` request on the worker thread pool."""
        self._request_execute(request)
        # Logged after submission so the count includes any thread just
        # started for this request.
        _LOGGER.debug("Currently using %s threads.",
                      len(self._worker_thread_pool._workers))

    def _request_process_bundle_split(self, request):
        """Handle a ``process_bundle_split`` request."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_progress(self, request):
        """Handle a ``process_bundle_progress`` request."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_action(self, request):
        """Submit a progress/split request that targets another instruction.

        The request is executed only if the instruction it refers to is
        currently in the bundle processor cache's active set; otherwise an
        error response naming the unknown instruction is enqueued.
        """
        def task():
            # instruction_id of the bundle this request is asking about, read
            # from the populated oneof payload (not the request's own id).
            instruction_id = getattr(
                request, request.WhichOneof('request')).instruction_id
            # only process progress/split request when a bundle is in processing.
            if (instruction_id
                    in self._bundle_processor_cache.active_bundle_processors):
                self._execute(
                    lambda: self.create_worker().do_instruction(request),
                    request)
            else:
                self._execute(
                    lambda: beam_fn_api_pb2.InstructionResponse(
                        instruction_id=request.instruction_id,
                        error=('Unknown process bundle instruction {}').format(
                            instruction_id)), request)

        self._worker_thread_pool.submit(task)

    def _request_finalize_bundle(self, request):
        """Handle a ``finalize_bundle`` request on the worker thread pool."""
        self._request_execute(request)

    def _request_execute(self, request):
        """Submit ``request`` to the worker thread pool.

        The submitted task builds a fresh SdkWorker, runs the instruction and
        enqueues the response (or an error response) via _execute().
        """
        def task():
            self._execute(lambda: self.create_worker().do_instruction(request),
                          request)

        self._worker_thread_pool.submit(task)

    def create_worker(self):
        """Return a new SdkWorker sharing this harness's bundle processor
        cache, state cache metrics function and profiler factory."""
        return SdkWorker(
            self._bundle_processor_cache,
            state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
            profiler_factory=self._profiler_factory)
Exemple #2
0
class SdkHarness(object):
  """Worker-side driver of the BeamFnControl request/response stream.

  The harness opens a gRPC control channel, reads instruction requests from
  it, dispatches each one to the ``_request_<type>`` method named after the
  populated ``request`` oneof field, and streams the responses produced by
  those handlers back over the same call.
  """
  # Prefix prepended to the request oneof name to find the handler method.
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
               ):
    """Connect to the control service and build the shared worker state.

    Blocks for up to 60 seconds waiting for the control channel to become
    ready; the error raised by the ready-future on timeout propagates to the
    caller.

    Args:
      control_address: Address the control channel connects to.
      credentials: Channel credentials. When None an insecure channel is
        created, otherwise a secure one.
      worker_id: Identifier handed to WorkerIdInterceptor and to the data
        channel factory.
      state_cache_size: Size passed to StateCache (0 by default).
      data_buffer_time_limit_ms: Passed to the data channel factory (0 by
        default).
      profiler_factory: Passed through to every SdkWorker built by
        create_worker().
      status_address: When truthy, a FnApiWorkerStatusHandler is created for
        this address. Failure to create it is logged and otherwise ignored.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    # Wrap the ready channel so calls go through WorkerIdInterceptor.
    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    # Descriptors are fetched on first lookup of a missing key. The lambda
    # reads self._control_stub, which is only assigned in run(), so lookups
    # must not happen before run() has started.
    self._fns = KeyedDefaultDict(
        lambda descriptor_id: self._control_stub.GetProcessBundleDescriptor(
            beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
                process_bundle_descriptor_id=descriptor_id))
    )  # type: Mapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Assigned up front so the attribute exists even when handler creation
    # fails below; run() tests it unconditionally during shutdown.
    self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache)
      except Exception:
        # Status reporting is best-effort: log and continue without it.
        traceback_string = traceback.format_exc()
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s', traceback_string)

    # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
    #  progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    # Responses produced by handler tasks, drained by run().
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    """Serve control requests until the request stream ends, then shut down.

    Each incoming request is dispatched to the handler method whose name is
    REQUEST_METHOD_PREFIX plus the populated ``request`` oneof field. Once
    the stream is exhausted both executors are drained, the response
    generator is released, and the data channel factory, state handler
    factory, bundle processor cache and (if present) status handler are
    closed.

    If the request loop raises, ``_alive`` is cleared and the exception
    propagates; the shutdown steps after the loop are not executed.
    """
    self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
        self._control_channel)
    # Unique sentinel that tells get_responses() to stop iterating.
    no_more_work = object()

    def get_responses():
      # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in self._control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Name spacing the request method with '_request_'. The called method
        # will be like self._request_register(request)
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed. The progress executor is
    # drained too, so its thread is released and any queued progress/split
    # response is enqueued before the end-of-stream sentinel.
    self._worker_thread_pool.shutdown()
    self._report_progress_executor.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    if self._status_handler:
      self._status_handler.close()
    _LOGGER.info('Done consuming work.')

  def _execute(self,
               task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
               request  # type:  beam_fn_api_pb2.InstructionRequest
              ):
    # type: (...) -> None
    """Run ``task`` and enqueue its result as the response to ``request``.

    The work runs inside ``statesampler.instruction_id(...)`` for the
    request's instruction id. If ``task`` raises, the traceback is printed to
    stderr and logged, and an InstructionResponse carrying the traceback in
    its ``error`` field is enqueued instead, so a response is queued either
    way.
    """
    with statesampler.instruction_id(request.instruction_id):
      try:
        response = task()
      except Exception:  # pylint: disable=broad-except
        traceback_string = traceback.format_exc()
        print(traceback_string, file=sys.stderr)
        _LOGGER.error(
            'Error processing instruction %s. Original traceback is\n%s\n',
            request.instruction_id,
            traceback_string)
        response = beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=traceback_string)
      self._responses.put(response)

  def _request_register(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ``register`` request on the calling thread."""
    # registration request is handled synchronously
    self._execute(lambda: self.create_worker().do_instruction(request), request)

  def _request_process_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ``process_bundle`` request on the worker thread pool."""
    self._request_execute(request)

  def _request_process_bundle_split(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ``process_bundle_split`` request."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ``process_bundle_progress`` request."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Submit a progress/split request to the single-thread progress executor.

    The request is executed only if the instruction it refers to is currently
    in the bundle processor cache's active set; otherwise an error response
    naming the unknown instruction is enqueued.
    """

    def task():
      # instruction_id of the bundle this request is asking about, read from
      # the populated oneof payload (not the request's own id).
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)
      else:
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                error=('Unknown process bundle instruction {}').format(
                    instruction_id)),
            request)

    self._report_progress_executor.submit(task)

  def _request_finalize_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ``finalize_bundle`` request on the worker thread pool."""
    self._request_execute(request)

  def _request_execute(self, request):
    """Submit ``request`` to the worker thread pool and log the pool size.

    The submitted task builds a fresh SdkWorker, runs the instruction and
    enqueues the response (or an error response) via _execute().
    """
    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    _LOGGER.debug(
        "Currently using %s threads.", len(self._worker_thread_pool._workers))

  def create_worker(self):
    """Return a new SdkWorker sharing this harness's bundle processor cache,
    state cache metrics function and profiler factory."""
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)