示例#1
0
 def test_evict_all(self):
     """evict_all() empties the cache so no previously stored key is found."""
     state_cache = StateCache(5)
     entries = (("key", "value"), ("key2", "value2"))
     for state_key, value in entries:
         state_cache.put(state_key, "cache_token", value)
     self.assertEqual(len(state_cache), 2)
     state_cache.evict_all()
     self.assertEqual(len(state_cache), 0)
     for state_key, _ in entries:
         self.assertEqual(state_cache.get(state_key, "cache_token"), None)
示例#2
0
 def test_max_size(self):
     """The cache never holds more entries than the size it was built with."""
     cache = StateCache(2)
     cache.put("key", "cache_token", "value")
     cache.put("key2", "cache_token", "value")
     self.assertEqual(len(cache), 2)
     # Re-putting keys that are already cached must not grow the cache.
     cache.put("key2", "cache_token", "value")
     self.assertEqual(len(cache), 2)
     cache.put("key", "cache_token", "value")
     self.assertEqual(len(cache), 2)
     # A third distinct key exceeds the limit: an older entry must make room
     # and the newly inserted one must be retrievable.
     cache.put("key3", "cache_token", "value")
     self.assertEqual(len(cache), 2)
     self.assertEqual(cache.get("key3", "cache_token"), "value")
示例#3
0
class GrpcStateHandlerFactory(StateHandlerFactory):
    """A factory for ``GrpcStateHandler``.

    Caches the created channels by ``state descriptor url``.
    """
    def __init__(self, state_cache_size, credentials=None):
        """Create a factory with an empty handler cache.

        Args:
          state_cache_size: size of the ``StateCache`` to create, or an
            already-constructed ``StateCache`` to share. Some callers pass
            their own cache instance here; wrapping it in another
            ``StateCache`` would be wrong.
          credentials: optional gRPC channel credentials; when None the state
            channels are created insecure.
        """
        self._state_handler_cache = {}
        # Guards creation and removal of entries in _state_handler_cache.
        self._lock = threading.Lock()
        self._throwing_state_handler = ThrowingStateHandler()
        self._credentials = credentials
        if isinstance(state_cache_size, StateCache):
            self._state_cache = state_cache_size
        else:
            self._state_cache = StateCache(state_cache_size)

    def create_state_handler(self, api_service_descriptor):
        """Return the state handler for ``api_service_descriptor``.

        Handlers are created on first use and cached per url; a falsy
        descriptor yields the shared ``ThrowingStateHandler``.
        """
        if not api_service_descriptor:
            return self._throwing_state_handler
        url = api_service_descriptor.url
        # Lock-free fast path; the lookup is repeated under the lock so only
        # one handler is created per url. Holding the handler in a local
        # avoids a KeyError if close() clears the cache before we return.
        handler = self._state_handler_cache.get(url)
        if handler is None:
            with self._lock:
                handler = self._state_handler_cache.get(url)
                if handler is None:
                    handler = CachingMaterializingStateHandler(
                        self._state_cache,
                        GrpcStateHandler(
                            beam_fn_api_pb2_grpc.BeamFnStateStub(
                                self._create_channel(url))))
                    self._state_handler_cache[url] = handler
        return handler

    def _create_channel(self, url):
        """Open a state channel to ``url`` wrapped with WorkerIdInterceptor."""
        # Options to have no limits (-1) on the size of the messages
        # received or sent over the data plane. The actual buffer size is
        # controlled in a layer above.
        options = [('grpc.max_receive_message_length', -1),
                   ('grpc.max_send_message_length', -1)]
        if self._credentials is None:
            logging.info('Creating insecure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.insecure_channel(
                url, options=options)
        else:
            logging.info('Creating secure state channel for %s.', url)
            grpc_channel = GRPCChannelFactory.secure_channel(
                url, self._credentials, options=options)
        logging.info('State channel established.')
        # Add workerId to the grpc channel
        return grpc.intercept_channel(grpc_channel, WorkerIdInterceptor())

    def close(self):
        """Mark every cached handler done, drop them and evict cached state."""
        logging.info('Closing all cached gRPC state handlers.')
        # Same lock as create_state_handler so a handler cannot be added
        # while the cache is being torn down.
        with self._lock:
            for state_handler in self._state_handler_cache.values():
                state_handler.done()
            self._state_handler_cache.clear()
        self._state_cache.evict_all()
示例#4
0
 def test_overwrite(self):
     """Re-putting a key under a new token replaces the old entry."""
     state_cache = StateCache(2)
     for token, payload in (("cache_token", "value"),
                            ("cache_token2", "value2")):
         state_cache.put("key", token, payload)
     self.assertEqual(len(state_cache), 1)
     # Only the most recent token still resolves.
     self.assertEqual(state_cache.get("key", "cache_token"), None)
     self.assertEqual(state_cache.get("key", "cache_token2"), "value2")
示例#5
0
    def __init__(
            self,
            control_address,
            worker_count,
            credentials=None,
            worker_id=None,
            # Caching is disabled by default
            state_cache_size=0,
            profiler_factory=None):
        """Connect to the runner's control service and set up worker state.

        Args:
          control_address: address of the runner's control service.
          worker_count: number of threads in the pool shared by process and
            finalize bundle requests.
          credentials: optional gRPC channel credentials; an insecure channel
            is created when this is None.
          worker_id: optional worker id attached to calls on the control
            channel through ``WorkerIdInterceptor``.
          state_cache_size: size handed to ``StateCache``; 0 disables caching.
          profiler_factory: optional profiler factory passed to the workers.
        """
        self._alive = True
        self._worker_count = worker_count
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        # No limits (-1) on the size of messages sent or received over the
        # control channel, so large control messages are not rejected by
        # gRPC's default size cap.
        options = [('grpc.max_receive_message_length', -1),
                   ('grpc.max_send_message_length', -1)]
        if credentials is None:
            logging.info('Creating insecure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address, options=options)
        else:
            logging.info('Creating secure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials, options=options)
        # Wait up to 60 seconds for the control channel to become ready.
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        logging.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id)
        # NOTE(review): the StateCache instance (not its size) is passed here;
        # confirm GrpcStateHandlerFactory's constructor accepts an instance.
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory
        self._fns = {}
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)
        # workers for process/finalize bundle.
        self.workers = queue.Queue()
        # one worker for progress/split request.
        self.progress_worker = SdkWorker(
            self._bundle_processor_cache,
            profiler_factory=self._profiler_factory)
        # one thread is enough for getting the progress report.
        # Assumption:
        # Progress report generation should not do IO or wait on other resources.
        #  Without wait, having multiple threads will not improve performance and
        #  will only add complexity.
        self._progress_thread_pool = futures.ThreadPoolExecutor(max_workers=1)
        # finalize and process share one thread pool.
        self._process_thread_pool = futures.ThreadPoolExecutor(
            max_workers=self._worker_count)
        self._responses = queue.Queue()
        self._process_bundle_queue = queue.Queue()
        self._unscheduled_process_bundle = {}
        logging.info('Initializing SDKHarness with %s workers.',
                     self._worker_count)
示例#6
0
  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
               ):
    """Connect to the runner's control service and set up worker state.

    Args:
      control_address: address of the runner's control service.
      credentials: optional gRPC channel credentials; an insecure channel is
        created when this is None.
      worker_id: optional worker id attached to calls on the control channel
        through ``WorkerIdInterceptor``.
      state_cache_size: size handed to ``StateCache``; 0 disables caching.
      data_buffer_time_limit_ms: forwarded to the data channel factory; 0
        disables time-based buffering.
      profiler_factory: optional profiler factory.
      status_address: optional address of a worker status service; when
        given, a ``FnApiWorkerStatusHandler`` is created on a best-effort
        basis.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    # No limits (-1) on the size of messages sent or received over the control
    # channel, so large control messages are not rejected by gRPC's default
    # size cap.
    options = [('grpc.max_receive_message_length', -1),
               ('grpc.max_send_message_length', -1)]
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address, options=options)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials, options=options)
    # Wait up to 60 seconds for the control channel to become ready.
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Assigned up front so the attribute exists even when no status address
    # is given or creating the status handler below fails.
    self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache)
      except Exception:
        # Best effort: the harness keeps running without status reporting.
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s', traceback.format_exc())

    # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
    #  progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
示例#7
0
 def test_put_get(self):
     """A stored value is returned only for the token it was put with."""
     cache = StateCache(5)
     cache.put("key", "cache_token", "value")
     self.assertEqual(len(cache), 1)
     self.assertEqual(cache.get("key", "cache_token"), "value")
     self.assertEqual(cache.get("key", "cache_token2"), None)
     # A missing cache token must be rejected. Only the call goes inside
     # assertRaises: a wrapped assertEqual would raise AssertionError on a
     # wrong return value, and that would satisfy assertRaises(Exception).
     with self.assertRaises(Exception):
         cache.get("key", None)
示例#8
0
    def __init__(
            self,
            control_address,
            credentials=None,
            worker_id=None,
            # Caching is disabled by default
            state_cache_size=0,
            profiler_factory=None):
        """Connect to the runner's control service and set up worker state.

        Args:
          control_address: address of the runner's control service.
          credentials: optional gRPC channel credentials; an insecure channel
            is created when this is None.
          worker_id: optional worker id attached to calls on the control
            channel through ``WorkerIdInterceptor``.
          state_cache_size: size handed to ``StateCache``; 0 disables caching.
          profiler_factory: optional profiler factory.
        """
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        # No limits (-1) on the size of messages sent or received over the
        # control channel, so large control messages are not rejected by
        # gRPC's default size cap.
        options = [('grpc.max_receive_message_length', -1),
                   ('grpc.max_send_message_length', -1)]
        if credentials is None:
            _LOGGER.info('Creating insecure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address, options=options)
        else:
            _LOGGER.info('Creating secure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials, options=options)
        # Wait up to 60 seconds for the control channel to become ready.
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        _LOGGER.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory
        self._fns = {}
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)

        # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
        #  progress once dataflow runner's excessive progress polling is removed.
        self._report_progress_executor = futures.ThreadPoolExecutor(
            max_workers=1)
        self._worker_thread_pool = UnboundedThreadPoolExecutor()
        self._responses = queue.Queue()
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')
示例#9
0
 def __init__(self,
              unused_payload,  # type: None
              state,  # type: sdk_worker.StateHandler
              provision_info,  # type: Optional[ExtendedProvisionInfo]
              worker_manager,  # type: WorkerHandlerManager
             ):
   # type: (...) -> None
   """Run an SDK worker in-process over in-memory data channels.

   This handler passes itself as the control handler, builds a bundle
   processor cache whose state access goes through a ``StateCache`` of size
   ``STATE_CACHE_SIZE``, and creates the ``SdkWorker`` that serves bundles.
   """
   super(EmbeddedWorkerHandler, self).__init__(
       self, data_plane.InMemoryDataChannel(), state, provision_info)
   # Control requests are answered by this object directly.
   self.control_conn = self  # type: ignore  # need Protocol to describe this
   self.data_conn = self.data_plane_handler
   cache = StateCache(STATE_CACHE_SIZE)
   caching_handler = sdk_worker.CachingStateHandler(cache, state)
   handler_factory = SingletonStateHandlerFactory(caching_handler)
   # The worker side of the data plane is the inverse of the runner side.
   channel_factory = data_plane.InMemoryDataChannelFactory(
       self.data_plane_handler.inverse())
   self.bundle_processor_cache = sdk_worker.BundleProcessorCache(
       handler_factory,
       channel_factory,
       worker_manager._process_bundle_descriptors)
   self.worker = sdk_worker.SdkWorker(
       self.bundle_processor_cache,
       state_cache_metrics_fn=cache.get_monitoring_infos)
   self._uid_counter = 0
示例#10
0
  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               profiler_factory=None  # type: Optional[Callable[..., Profile]]
              ):
    """Connect to the runner's control service and set up worker state.

    Args:
      control_address: address of the runner's control service.
      credentials: optional gRPC channel credentials; an insecure channel is
        created when this is None.
      worker_id: optional worker id attached to calls on the control channel
        through ``WorkerIdInterceptor``.
      state_cache_size: size handed to ``StateCache``; 0 disables caching.
      profiler_factory: optional profiler factory.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    # No limits (-1) on the size of messages sent or received over the control
    # channel, so large control messages are not rejected by gRPC's default
    # size cap.
    options = [('grpc.max_receive_message_length', -1),
               ('grpc.max_send_message_length', -1)]
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address, options=options)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials, options=options)
    # Wait up to 60 seconds for the control channel to become ready.
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
示例#11
0
  def __init__(
      self,
      control_address,  # type: str
      credentials=None,  # type: Optional[grpc.ChannelCredentials]
      worker_id=None,  # type: Optional[str]
      # Caching is disabled by default
      state_cache_size=0,  # type: int
      # time-based data buffering is disabled by default
      data_buffer_time_limit_ms=0,  # type: int
      profiler_factory=None,  # type: Optional[Callable[..., Profile]]
      status_address=None,  # type: Optional[str]
      # Heap dump through status api is disabled by default
      enable_heap_dump=False,  # type: bool
  ):
    # type: (...) -> None
    """Connect to the runner's control service and set up worker state.

    Args:
      control_address: address of the runner's control service.
      credentials: optional gRPC channel credentials; an insecure channel is
        created when this is None.
      worker_id: optional worker id attached to calls on the control channel
        through ``WorkerIdInterceptor``.
      state_cache_size: size handed to ``StateCache``; 0 disables caching.
      data_buffer_time_limit_ms: forwarded to the data channel factory; 0
        disables time-based buffering.
      profiler_factory: optional profiler factory.
      status_address: optional address of a worker status service; when
        given, a ``FnApiWorkerStatusHandler`` is created on a best-effort
        basis.
      enable_heap_dump: forwarded to the status handler.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    # No limits (-1) on the size of messages sent or received over the control
    # channel.
    options = [('grpc.max_receive_message_length', -1),
               ('grpc.max_send_message_length', -1)]
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address, options=options)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials, options=options)
    # Wait up to 60 seconds for the control channel to become ready.
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory

    def default_factory(descriptor_id):
      # type: (str) -> beam_fn_api_pb2.ProcessBundleDescriptor
      # NOTE(review): self._control_stub is not assigned in this constructor;
      # presumably it is set before the first descriptor lookup — confirm.
      return self._control_stub.GetProcessBundleDescriptor(
          beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
              process_bundle_descriptor_id=descriptor_id))

    # Descriptors missing from the dict are requested from the runner through
    # default_factory.
    self._fns = KeyedDefaultDict(default_factory)
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Assigned up front so the attribute exists even when no status address
    # is given or creating the status handler below fails.
    self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache, enable_heap_dump)
      except Exception:
        # Best effort: the harness keeps running without status reporting.
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s', traceback.format_exc())

    # TODO(BEAM-8998) use common
    # thread_pool_executor.shared_unbounded_instance() to process bundle
    # progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance()
    self._responses = queue.Queue(
    )  # type: queue.Queue[Union[beam_fn_api_pb2.InstructionResponse, Sentinel]]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
示例#12
0
 def get_cache(size):
     """Build a ``StateCache`` of ``size`` entries with its metrics initialized."""
     state_cache = StateCache(size)
     state_cache.initialize_metrics()
     return state_cache
示例#13
0
 def __init__(self, state_cache_size, credentials=None):
     """Create a factory with an empty handler cache.

     Args:
       state_cache_size: size of the ``StateCache`` to create, or an
         already-constructed ``StateCache`` to share. Some callers pass their
         own cache instance here; wrapping it in another ``StateCache`` would
         be wrong.
       credentials: optional gRPC channel credentials, stored for later use.
     """
     self._state_handler_cache = {}
     self._lock = threading.Lock()
     self._throwing_state_handler = ThrowingStateHandler()
     self._credentials = credentials
     if isinstance(state_cache_size, StateCache):
         self._state_cache = state_cache_size
     else:
         self._state_cache = StateCache(state_cache_size)
示例#14
0
 def test_lru(self):
     """Entries are evicted in least-recently-used order once the cache is full.

     Every lookup that checks for eviction uses the token the entry was stored
     under. A token mismatch would also return None, which would let the
     assertion pass even if the entry had not been evicted.
     """
     cache = StateCache(5)
     cache.put("key", "cache_token", "value")
     cache.put("key2", "cache_token2", "value2")
     cache.put("key3", "cache_token", "value0")
     cache.put("key3", "cache_token", "value3")
     cache.put("key4", "cache_token4", "value4")
     cache.put("key5", "cache_token", "value0")
     cache.put("key5", "cache_token", ["value5"])
     self.assertEqual(len(cache), 5)
     self.assertEqual(cache.get("key", "cache_token"), "value")
     self.assertEqual(cache.get("key2", "cache_token2"), "value2")
     self.assertEqual(cache.get("key3", "cache_token"), "value3")
     self.assertEqual(cache.get("key4", "cache_token4"), "value4")
     self.assertEqual(cache.get("key5", "cache_token"), ["value5"])
     # insert another key to trigger cache eviction
     cache.put("key6", "cache_token2", "value7")
     self.assertEqual(len(cache), 5)
     # least recently used key should be gone ("key")
     self.assertEqual(cache.get("key", "cache_token"), None)
     # trigger a read on "key2" with the token it was stored under
     self.assertEqual(cache.get("key2", "cache_token2"), "value2")
     # insert another key to trigger cache eviction
     cache.put("key7", "cache_token", "value7")
     self.assertEqual(len(cache), 5)
     # least recently used key should be gone ("key3")
     self.assertEqual(cache.get("key3", "cache_token"), None)
     # trigger a put on "key2"
     cache.put("key2", "cache_token", "put")
     self.assertEqual(len(cache), 5)
     # insert another key to trigger cache eviction
     cache.put("key8", "cache_token", "value8")
     self.assertEqual(len(cache), 5)
     # least recently used key should be gone ("key4")
     self.assertEqual(cache.get("key4", "cache_token4"), None)
     # make "key5" used by appending to it; this adds no entry, so nothing is
     # evicted yet, but "key6" becomes the least recently used entry
     cache.extend("key5", "cache_token", ["another"])
     # insert another key to trigger cache eviction
     cache.put("key9", "cache_token", "value9")
     self.assertEqual(len(cache), 5)
     # least recently used key should be gone ("key6") while "key5" survives
     self.assertEqual(cache.get("key6", "cache_token2"), None)
     self.assertEqual(cache.get("key5", "cache_token"), ["value5", "another"])
示例#15
0
 def test_is_cached_enabled(self):
     """A positive size enables the cache; a size of 0 disables it."""
     for size, expected in ((1, True), (0, False)):
         self.assertEqual(StateCache(size).is_cache_enabled(), expected)
示例#16
0
 def test_clear(self):
     """clear() caches an empty value, and evicts the entry on a token change."""
     state_cache = StateCache(5)
     token = "cache_token"
     state_cache.clear("new-key", token)
     state_cache.put("key", token, ["value"])
     self.assertEqual(len(state_cache), 2)
     self.assertEqual(state_cache.get("new-key", "new_token"), None)
     self.assertEqual(state_cache.get("key", token), ['value'])

     # Clearing a key/token pair that was never cached adds an empty entry.
     state_cache.clear("non-existing", "token")
     self.assertEqual(len(state_cache), 3)
     self.assertEqual(state_cache.get("non-existing", "token"), [])

     # Clearing under a different token drops the existing entry instead.
     stale_token = "wrong_token"
     state_cache.clear("new-key", stale_token)
     self.assertEqual(len(state_cache), 2)
     for probe_token in (token, stale_token):
         self.assertEqual(state_cache.get("new-key", probe_token), None)
示例#17
0
 def test_extend(self):
     """extend() appends to, creates, or (on a token change) evicts entries."""
     state_cache = StateCache(3)
     token = "cache_token"
     state_cache.put("key", token, ['val'])

     # Appending to an existing entry keeps a single entry.
     state_cache.extend("key", token, ['yet', 'another', 'val'])
     self.assertEqual(len(state_cache), 1)
     expected = ['val', 'yet', 'another', 'val']
     self.assertEqual(state_cache.get("key", token), expected)

     # Extending a missing key creates it.
     state_cache.extend("key2", token, ['another', 'val'])
     self.assertEqual(len(state_cache), 2)
     self.assertEqual(state_cache.get("key2", token), ['another', 'val'])

     # Extending under a different token evicts the existing entry.
     state_cache.extend("key2", "new_token", ['new_value'])
     self.assertEqual(state_cache.get("key2", "new_token"), None)
     self.assertEqual(len(state_cache), 1)
示例#18
0
 def test_empty_cache_get(self):
     """An empty cache misses, and still rejects a missing cache token."""
     cache = StateCache(5)
     self.assertEqual(cache.get("key", 'cache_token'), None)
     # Only the call goes inside assertRaises: a wrapped assertEqual would
     # raise AssertionError on a wrong return value, and that would satisfy
     # assertRaises(Exception).
     with self.assertRaises(Exception):
         cache.get("key", None)