Example #1
0
 def create_state_handler(self, api_service_descriptor):
     """Return the cached state handler for an API service descriptor.

     Args:
       api_service_descriptor: descriptor whose ``url`` identifies the state
         endpoint. A falsy value selects ``self._throwing_state_handler``.

     Returns:
       The handler stored in ``self._state_handler_cache`` for the
       descriptor's URL. On first use it is built as a CachingStateHandler
       (sharing ``self._state_cache``) around a GrpcStateHandler whose channel
       is secure when ``self._credentials`` is set and insecure otherwise.
       Creation happens under ``self._lock`` with a second membership check,
       so (assuming ``self._lock`` is a mutex) at most one handler is built
       per URL.
     """
     if not api_service_descriptor:
         return self._throwing_state_handler
     endpoint = api_service_descriptor.url
     if endpoint in self._state_handler_cache:
         return self._state_handler_cache[endpoint]
     with self._lock:
         # Re-check under the lock: another thread may have created the
         # handler between the unlocked check above and acquiring the lock.
         if endpoint not in self._state_handler_cache:
             # -1 removes gRPC's size limits on messages sent or received
             # over this state channel; buffer sizes are controlled by a
             # layer above.
             unlimited = [('grpc.max_receive_message_length', -1),
                          ('grpc.max_send_message_length', -1)]
             if self._credentials is not None:
                 _LOGGER.info('Creating secure state channel for %s.',
                              endpoint)
                 raw_channel = GRPCChannelFactory.secure_channel(
                     endpoint, self._credentials, options=unlimited)
             else:
                 _LOGGER.info('Creating insecure state channel for %s.',
                              endpoint)
                 raw_channel = GRPCChannelFactory.insecure_channel(
                     endpoint, options=unlimited)
             _LOGGER.info('State channel established.')
             # Attach the worker id to calls on this channel via the
             # WorkerIdInterceptor.
             tagged_channel = grpc.intercept_channel(
                 raw_channel, WorkerIdInterceptor())
             self._state_handler_cache[endpoint] = CachingStateHandler(
                 self._state_cache,
                 GrpcStateHandler(
                     beam_fn_api_pb2_grpc.BeamFnStateStub(tagged_channel)))
     return self._state_handler_cache[endpoint]
Example #2
0
  def run(self):
    """Serve the control plane until the runner closes the instruction stream.

    Builds ``self._worker_count`` SdkWorkers, each with its own started
    GrpcStateHandler over the control channel, and puts them on
    ``self.workers``. Every incoming instruction is then dispatched to the
    method named ``SdkHarness.REQUEST_METHOD_PREFIX`` + its request type.
    When the stream ends, the thread pools are shut down, the response
    stream is terminated, the data channel factory is closed and every
    worker's state handler is stopped.
    """
    stub = beam_fn_api_pb2_grpc.BeamFnControlStub(self._control_channel)
    end_of_stream = object()

    # All workers share self._fns. Function registration (register) and
    # execution (process_bundle) arrive as separate requests, and which
    # worker will run a bundle is unknown until its process_bundle request
    # shows up. The same function may also be reused by several bundles that
    # run on different workers, so the function list is kept centrally.
    for _worker_index in range(self._worker_count):
      handler = GrpcStateHandler(
          beam_fn_api_pb2_grpc.BeamFnStateStub(self._control_channel))
      handler.start()
      self.workers.put(
          SdkWorker(state_handler=handler,
                    data_channel_factory=self._data_channel_factory,
                    fns=self._fns))

    def outgoing_responses():
      # Feed queued responses into the control stream until the sentinel
      # is dequeued.
      item = self._responses.get()
      while item is not end_of_stream:
        yield item
        item = self._responses.get()

    for instruction in stub.Control(outgoing_responses()):
      logging.info('Got work %s', instruction.instruction_id)
      # Handlers are namespaced by a prefix, e.g. self.request_register.
      handler_name = (SdkHarness.REQUEST_METHOD_PREFIX +
                      instruction.WhichOneof('request'))
      getattr(self, handler_name)(instruction)

    logging.info('No more requests from control plane')
    logging.info('SDK Harness waiting for in-flight requests to complete')
    # Wait for the requests that are already submitted to be processed.
    self._progress_thread_pool.shutdown()
    self._process_thread_pool.shutdown()
    # outgoing_responses() may be blocked on self._responses.get(); the
    # sentinel lets it return control to its caller.
    self._responses.put(end_of_stream)
    self._data_channel_factory.close()
    # Stop each worker's state handler and release its resources.
    for idle_worker in self.workers.queue:
      idle_worker.state_handler.done()
    logging.info('Done consuming work.')
Example #3
0
 def create_state_handler(self, api_service_descriptor):
     """Return the cached GrpcStateHandler for an API service descriptor.

     Args:
       api_service_descriptor: descriptor whose ``url`` identifies the state
         endpoint. A falsy value selects ``self._throwing_state_handler``.

     Returns:
       The handler stored in ``self._state_handler_cache`` for the
       descriptor's URL. On first use it is created over an insecure,
       worker-id-intercepted gRPC channel. Creation happens under
       ``self._lock`` with a second membership check, so (assuming
       ``self._lock`` is a mutex) at most one handler is built per URL.
     """
     if not api_service_descriptor:
         return self._throwing_state_handler
     endpoint = api_service_descriptor.url
     if endpoint in self._state_handler_cache:
         return self._state_handler_cache[endpoint]
     with self._lock:
         # Re-check under the lock: another thread may have created the
         # handler after the unlocked check above.
         if endpoint not in self._state_handler_cache:
             logging.info('Creating channel for %s', endpoint)
             # -1 removes gRPC's size limits on messages sent or received
             # over this state channel; buffer sizes are controlled by a
             # layer above.
             unlimited = [("grpc.max_receive_message_length", -1),
                          ("grpc.max_send_message_length", -1)]
             raw_channel = grpc.insecure_channel(endpoint, options=unlimited)
             # Attach the worker id to calls on this channel via the
             # WorkerIdInterceptor.
             tagged_channel = grpc.intercept_channel(
                 raw_channel, WorkerIdInterceptor())
             self._state_handler_cache[endpoint] = GrpcStateHandler(
                 beam_fn_api_pb2_grpc.BeamFnStateStub(tagged_channel))
     return self._state_handler_cache[endpoint]
Example #4
0
File: sdk_worker.py Project: x/beam
    def run(self):
        """Consume instructions from the control plane until it closes.

        Starts a GrpcStateHandler on the control channel and builds
        ``self.worker``. It then opens the control stream and submits each
        incoming instruction to a thread pool. Progress requests go to
        ``self._progress_thread_pool``; everything else goes to
        ``self._default_work_thread_pool``. Each response, or an error
        response if the instruction failed, is streamed back. When the
        runner ends the stream, it waits for the pools, terminates the
        response stream, closes the data channel factory and stops the
        state handler.
        """
        control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
            self._control_channel)
        state_stub = beam_fn_api_pb2_grpc.BeamFnStateStub(
            self._control_channel)
        state_handler = GrpcStateHandler(state_stub)
        state_handler.start()
        self.worker = SdkWorker(state_handler, self._data_channel_factory)

        responses = queue.Queue()
        # Sentinel that tells get_responses() to end the outgoing stream.
        no_more_work = object()

        def get_responses():
            while True:
                response = responses.get()
                if response is no_more_work:
                    return
                yield response

        # Need this wrapper to capture the original stack trace.
        def do_instruction(request):
            try:
                return self.worker.do_instruction(request)
            except Exception:  # pylint: disable=broad-except
                # format_exc() formats the exception currently being handled.
                # Its only positional parameter is a depth *limit*. Passing the
                # exception there raises TypeError on Python 3 and hides the
                # original error.
                traceback_str = traceback.format_exc()
                raise Exception(
                    "Error processing request. Original traceback "
                    "is\n%s\n" % traceback_str)

        def handle_response(request, response_future):
            # Turn the finished future into an InstructionResponse (an error
            # response on failure) and queue it for get_responses().
            try:
                response = response_future.result()
            except Exception as e:  # pylint: disable=broad-except
                logging.error('Error processing instruction %s',
                              request.instruction_id,
                              exc_info=True)
                response = beam_fn_api_pb2.InstructionResponse(
                    instruction_id=request.instruction_id, error=str(e))
            responses.put(response)

        for work_request in control_stub.Control(get_responses()):
            logging.info('Got work %s', work_request.instruction_id)
            request_type = work_request.WhichOneof('request')
            # WhichOneof returns the field name as a str. The old comparison
            # against a one-element list was always False, so progress
            # requests never reached the dedicated pool. Presumably that pool
            # keeps them from waiting behind long-running bundles.
            if request_type == 'process_bundle_progress':
                thread_pool = self._progress_thread_pool
            else:
                thread_pool = self._default_work_thread_pool

            thread_pool.submit(do_instruction, work_request).add_done_callback(
                functools.partial(handle_response, work_request))

        logging.info("No more requests from control plane")
        logging.info("SDK Harness waiting for in-flight requests to complete")
        # Wait until existing requests are processed.
        self._progress_thread_pool.shutdown()
        self._default_work_thread_pool.shutdown()
        # get_responses may be blocked on responses.get(), but we need to return
        # control to its caller.
        responses.put(no_more_work)
        self._data_channel_factory.close()
        state_handler.done()
        logging.info('Done consuming work.')