def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
               ):
    """Connects to the control service and builds the harness components.

    Blocks for up to 60 seconds waiting for the control channel to be ready.

    Args:
      control_address: address of the runner's control service.
      credentials: optional gRPC channel credentials; when None an insecure
        control channel is created.
      worker_id: optional worker id handed to WorkerIdInterceptor and to the
        data channel factory.
      state_cache_size: size passed to StateCache (0 by default).
      data_buffer_time_limit_ms: passed to GrpcClientDataChannelFactory
        (0 by default).
      profiler_factory: optional profiler factory, stored for later use.
      status_address: optional address of a worker status service. If the
        status handler cannot be created, a warning is logged and
        ``self._status_handler`` is left as None.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Always define the attribute: previously a failure while creating the
    # handler left it unset, so any later read raised AttributeError.
    self._status_handler = None
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache)
      except Exception:  # pylint: disable=broad-except
        # Status reporting is best-effort; keep the harness running.
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s', traceback.format_exc())

    # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
    #  progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
Exemple #2
0
    def run(self):
        """Runs the worker command line as a subprocess next to a logging server.

        Starts a local gRPC logging server and launches
        ``self._worker_command_line`` through the shell. The control and
        logging endpoints (and WORKER_ID, when set) are exported as
        environment variables. Then waits for the subprocess to exit.

        Raises:
          RuntimeError: if the worker subprocess exits with a non-zero code.
        """
        logging_server = grpc.server(UnboundedThreadPoolExecutor())
        logging_servicer = BeamFnLoggingServicer()
        # Register handlers before start(): gRPC only guarantees handler
        # registration to be safe on a server that has not been started.
        beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
            logging_servicer, logging_server)
        logging_port = logging_server.add_insecure_port('[::]:0')
        logging_server.start()
        # Everything after start() runs under try/finally so the logging
        # server is stopped even when setup or Popen raises.
        try:
            logging_descriptor = text_format.MessageToString(
                endpoints_pb2.ApiServiceDescriptor(url='localhost:%s' %
                                                   logging_port))

            control_descriptor = text_format.MessageToString(
                endpoints_pb2.ApiServiceDescriptor(url=self._control_address))

            env_dict = dict(os.environ,
                            CONTROL_API_SERVICE_DESCRIPTOR=control_descriptor,
                            LOGGING_API_SERVICE_DESCRIPTOR=logging_descriptor)
            # only add worker_id when it is set.
            if self._worker_id:
                env_dict['WORKER_ID'] = self._worker_id

            with fn_api_runner.SUBPROCESS_LOCK:
                p = subprocess.Popen(self._worker_command_line,
                                     shell=True,
                                     env=env_dict)
            try:
                p.wait()
                if p.returncode:
                    raise RuntimeError(
                        'Worker subprocess exited with return code %s' %
                        p.returncode)
            finally:
                # Never leave the worker running if waiting was interrupted.
                if p.poll() is None:
                    p.kill()
        finally:
            logging_server.stop(0)
Exemple #3
0
    def start(
            cls,
            use_process=False,
            port=0,
            state_cache_size=0,
            data_buffer_time_limit_ms=-1,
            container_executable=None  # type: Optional[str]
    ):
        # type: (...) -> Tuple[str, grpc.Server]
        """Creates a worker pool servicer and serves it on a new gRPC server.

        An atexit hook is registered that kills every process in the pool's
        ``_worker_processes`` mapping.

        Returns:
          A ``(address, server)`` tuple: the ``localhost:<port>`` address of
          the pool and the started gRPC server.
        """
        server = grpc.server(UnboundedThreadPoolExecutor())
        bound_port = server.add_insecure_port('[::]:%s' % port)
        address = 'localhost:%s' % bound_port
        pool = cls(use_process=use_process,
                   container_executable=container_executable,
                   state_cache_size=state_cache_size,
                   data_buffer_time_limit_ms=data_buffer_time_limit_ms)
        beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
            pool, server)
        server.start()

        def _kill_all_workers():
            # Runs at interpreter exit so worker subprocesses do not outlive us.
            for proc in pool._worker_processes.values():
                proc.kill()

        atexit.register(_kill_all_workers)
        return address, server
Exemple #4
0
  def _check_fn_registration_multi_request(self, *args):
    """Check the function registration calls to the sdk_harness.

    Args:
      *args: ``(request_count, process_bundles_per_request)`` tuples. For each
        tuple, ``request_count`` register requests are served to a fresh
        harness. Request ``i`` is built from all descriptors accumulated so
        far. Afterwards, the harness's bundle processor cache must hold exactly
        those descriptors, keyed by id.
    """
    for (request_count, process_bundles_per_request) in args:
      requests = []
      process_bundle_descriptors = []

      for i in range(request_count):
        pbd = self._get_process_bundles(i, process_bundles_per_request)
        process_bundle_descriptors.extend(pbd)
        requests.append(
            beam_fn_api_pb2.InstructionRequest(
                instruction_id=str(i),
                register=beam_fn_api_pb2.RegisterRequest(
                    process_bundle_descriptor=process_bundle_descriptors)))

      test_controller = BeamFnControlServicer(requests)

      server = grpc.server(UnboundedThreadPoolExecutor())
      beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
          test_controller, server)
      test_port = server.add_insecure_port("[::]:0")
      server.start()
      try:
        harness = sdk_worker.SdkHarness(
            "localhost:%s" % test_port, state_cache_size=100)
        harness.run()

        self.assertEqual(harness._bundle_processor_cache.fns,
                         {item.id: item
                          for item in process_bundle_descriptors})
      finally:
        # One server is created per tuple; stop it so ports and threads are
        # not leaked across iterations.
        server.stop(0)
Exemple #5
0
    def _stage_files(self, files):
        """Utility method to stage files.

        Args:
          files: a list of tuples of the form [(local_name, remote_name),...]
            describing the name of the artifacts in local temp folder and
            desired name in staging location.

        Returns:
          A ``(artifacts, retrieval_tokens)`` tuple read from the test staging
          service after the manifest is committed.
        """
        server = grpc.server(UnboundedThreadPoolExecutor())
        staging_service = TestLocalFileSystemArtifactStagingServiceServicer(
            self._remote_dir)
        beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
            staging_service, server)
        test_port = server.add_insecure_port('[::]:0')
        server.start()
        channel = grpc.insecure_channel('localhost:%s' % test_port)
        try:
            stager = portable_stager.PortableStager(
                artifact_service_channel=channel,
                staging_session_token='token')
            for from_file, to_file in files:
                stager.stage_artifact(
                    local_path_to_artifact=os.path.join(self._temp_dir,
                                                        from_file),
                    artifact_name=to_file)
            stager.commit_manifest()
            return (staging_service.manifest.artifact,
                    staging_service.retrieval_tokens)
        finally:
            # Release the client channel and the server even on failure.
            channel.close()
            server.stop(0)
Exemple #6
0
  def _grpc_data_channel_test(self, time_based_flush=False):
    """Runs the shared data-channel checks over a real local gRPC server.

    Args:
      time_based_flush: if True, both the servicer and the client channel are
        created with ``data_buffer_time_limit_ms=100``.
    """
    if time_based_flush:
      data_servicer = data_plane.BeamFnDataServicer(
          data_buffer_time_limit_ms=100)
    else:
      data_servicer = data_plane.BeamFnDataServicer()
    worker_id = 'worker_0'
    data_channel_service = \
      data_servicer.get_conn_by_worker_id(worker_id)

    server = grpc.server(UnboundedThreadPoolExecutor())
    beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
        data_servicer, server)
    test_port = server.add_insecure_port('[::]:0')
    server.start()

    grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
    # Add workerId to the grpc channel
    grpc_channel = grpc.intercept_channel(
        grpc_channel, WorkerIdInterceptor(worker_id))
    data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)
    if time_based_flush:
      data_channel_client = data_plane.GrpcClientDataChannel(
          data_channel_stub, data_buffer_time_limit_ms=100)
    else:
      data_channel_client = data_plane.GrpcClientDataChannel(data_channel_stub)

    try:
      self._data_channel_test(
          data_channel_service, data_channel_client, time_based_flush)
    finally:
      data_channel_client.close()
      data_channel_service.close()
      data_channel_client.wait()
      data_channel_service.wait()
      # Stop the server so its port and threads do not outlive the test.
      server.stop(0)
Exemple #7
0
    def __init__(
            self,
            control_address,
            credentials=None,
            worker_id=None,
            # Caching is disabled by default
            state_cache_size=0,
            profiler_factory=None):
        """Opens the control channel and builds the harness's components.

        Blocks for up to 60 seconds until the control channel is ready. A
        secure channel is used when ``credentials`` is given; otherwise the
        channel is insecure.
        """
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)

        if credentials is not None:
            _LOGGER.info('Creating secure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials)
        else:
            _LOGGER.info('Creating insecure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address)
        ready = grpc.channel_ready_future(self._control_channel)
        ready.result(timeout=60)
        _LOGGER.info('Control channel established.')

        # Wrap the raw channel with WorkerIdInterceptor.
        interceptor = WorkerIdInterceptor(self._worker_id)
        self._control_channel = grpc.intercept_channel(
            self._control_channel, interceptor)

        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory

        # The same dict object is handed to the BundleProcessor cache, which
        # is shared across all workers.
        self._fns = {}
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)

        # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
        #  progress once dataflow runner's excessive progress polling is removed.
        self._report_progress_executor = futures.ThreadPoolExecutor(
            max_workers=1)
        self._worker_thread_pool = UnboundedThreadPoolExecutor()
        self._responses = queue.Queue()
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')
Exemple #8
0
    def test_concurrent_requests(self):
        """Stages, commits and retrieves artifacts from many sessions at once.

        100 uploads are spread over ``num_sessions`` sessions with overlapping
        artifact names. Each session is committed, then every artifact is
        read back and compared with the bytes that were uploaded.
        """

        num_sessions = 7
        artifacts = collections.defaultdict(list)

        def name(index):
            # Overlapping names across sessions.
            return 'name%d' % (index // num_sessions)

        def session(index):
            return 'session%d' % (index % num_sessions)

        def delayed_data(data, index, max_msecs=1):
            time.sleep(max_msecs / 1000.0 * random.random())
            return ('%s_%d' % (data, index)).encode('ascii')

        def put(index):
            artifacts[session(index)].append(
                beam_artifact_api_pb2.ArtifactMetadata(name=name(index)))
            self._service.PutArtifact([
                self.put_metadata(session(index), name(index)),
                self.put_data(delayed_data('a', index)),
                self.put_data(delayed_data('b' * 20, index, 2))
            ])
            return session(index)

        def commit(session):
            return session, self._service.CommitManifest(
                beam_artifact_api_pb2.CommitManifestRequest(
                    staging_session_token=session,
                    manifest=beam_artifact_api_pb2.Manifest(
                        artifact=artifacts[session]))).retrieval_token

        def check(index):
            self.assertEqual(
                delayed_data('a', index) + delayed_data('b' * 20, index, 2),
                self.retrieve_artifact(self._service, tokens[session(index)],
                                       name(index)))

        # Use the executor as a context manager so its threads are shut down
        # at the end of the test instead of being leaked.
        with UnboundedThreadPoolExecutor() as pool:
            # pylint: disable=range-builtin-not-iterating
            sessions = set(pool.map(put, range(100)))
            tokens = dict(pool.map(commit, sessions))
            # List forces materialization.
            _ = list(pool.map(check, range(100)))
Exemple #9
0
    def process_bundle(
        self,
        inputs,  # type: Mapping[str, execution.PartitionableBuffer]
        expected_outputs,  # type: DataOutput
        fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
        expected_output_timers,  # type: Dict[Tuple[str, str], str]
        dry_run=False,
    ):
        # type: (...) -> BundleProcessResult
        """Partitions a bundle over the workers and merges their results.

        Each input buffer is partitioned into ``self._num_workers`` parts.
        Every part runs through its own BundleManager on an
        UnboundedThreadPoolExecutor. Monitoring infos are consolidated across
        parts, and the first non-empty error encountered while merging wins.

        Returns:
          A ``(merged_result, split_result_list)`` tuple.
        """
        worker_count = self._num_workers
        # Timers are only executed on the first worker
        # TODO(BEAM-9741): Split timers to multiple workers
        timer_inputs = [
            fired_timers if worker_ix == 0 else {}
            for worker_ix in range(worker_count)
        ]
        part_inputs = [{} for _ in range(worker_count)
                       ]  # type: List[Dict[str, List[bytes]]]
        for input_name, buffer in inputs.items():
            for worker_ix, part in enumerate(buffer.partition(worker_count)):
                part_inputs[worker_ix][input_name] = part

        merged_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
        split_result_list = [
        ]  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]

        def run_part(work_item):
            # type: (...) -> BundleProcessResult
            part_map, part_timers = work_item
            manager = BundleManager(
                self.bundle_context_manager,
                self._progress_frequency,
                cache_token_generator=self._cache_token_generator)
            return manager.process_bundle(part_map, expected_outputs,
                                          part_timers, expected_output_timers,
                                          dry_run)

        work_items = zip(part_inputs, timer_inputs)  # pylint: disable=zip-builtin-not-iterating
        with UnboundedThreadPoolExecutor() as executor:
            for result, split_result in executor.map(run_part, work_items):
                split_result_list += split_result
                if merged_result is None:
                    merged_result = result
                    continue
                consolidated = monitoring_infos.consolidate(
                    itertools.chain(
                        result.process_bundle.monitoring_infos,
                        merged_result.process_bundle.monitoring_infos))
                merged_result = beam_fn_api_pb2.InstructionResponse(
                    process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                        monitoring_infos=consolidated),
                    error=result.error or merged_result.error)
        assert merged_result is not None
        return merged_result, split_result_list
Exemple #10
0
 def setUp(self):
     """Serves a BeamFnStatusServicer locally and connects a status handler."""
     self.num_request = 3
     servicer = BeamFnStatusServicer(self.num_request)
     self.test_status_service = servicer
     server = grpc.server(UnboundedThreadPoolExecutor())
     self.server = server
     beam_fn_api_pb2_grpc.add_BeamFnWorkerStatusServicer_to_server(
         servicer, server)
     bound_port = server.add_insecure_port('[::]:0')
     self.test_port = bound_port
     server.start()
     self.url = 'localhost:%s' % bound_port
     self.fn_status_handler = FnApiWorkerStatusHandler(self.url)
Exemple #11
0
    def test_shutdown_with_slow_workers(self):
        """Five one-second tasks all finish and each records a worker ident."""
        # Named `pending` so the local does not shadow the `futures` module.
        with UnboundedThreadPoolExecutor() as executor:
            pending = [
                executor.submit(self.append_and_sleep, 1) for _ in range(5)
            ]

        for fut in pending:
            fut.result(timeout=10)

        with self._lock:
            self.assertEqual(5, len(self._worker_idents))
Exemple #12
0
 def start_grpc_server(self, port=0):
     """Serves the job and artifact-staging services on localhost.

     Args:
       port: port requested from ``add_insecure_port``.

     Returns:
       The port returned by ``add_insecure_port``. The artifact-staging
       endpoint is advertised on this port.
     """
     server = grpc.server(UnboundedThreadPoolExecutor())
     self._server = server
     bound_port = server.add_insecure_port('localhost:%d' % port)
     beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, server)
     beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
         self._artifact_service, server)
     staging_url = 'localhost:%d' % bound_port
     self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
         url=staging_url)
     server.start()
     _LOGGER.info('Grpc server started on port %s', bound_port)
     return bound_port
Exemple #13
0
  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               profiler_factory=None  # type: Optional[Callable[..., Profile]]
              ):
    """Opens the control channel and builds the harness's components.

    Blocks for up to 60 seconds until the control channel is ready. A secure
    channel is used when ``credentials`` is given; otherwise the channel is
    insecure.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)

    if credentials is not None:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    else:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    ready = grpc.channel_ready_future(self._control_channel)
    ready.result(timeout=60)
    _LOGGER.info('Control channel established.')

    # Wrap the raw channel with WorkerIdInterceptor.
    interceptor = WorkerIdInterceptor(self._worker_id)
    self._control_channel = grpc.intercept_channel(
        self._control_channel, interceptor)

    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory

    # The same dict object is handed to the BundleProcessor cache, which is
    # shared across all workers.
    self._fns = {}  # type: Dict[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')
Exemple #14
0
 def start_grpc_server(self, port=0):
   """Serves the job and artifact-staging services.

   The server binds to ``self.get_bind_address()``. The artifact-staging
   endpoint is advertised at ``self.get_service_address()``.

   Args:
     port: port requested from ``add_insecure_port``.

   Returns:
     The port returned by ``add_insecure_port``.
   """
   self._server = grpc.server(UnboundedThreadPoolExecutor())
   port = self._server.add_insecure_port(
       '%s:%d' % (self.get_bind_address(), port))
   beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
   beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
       self._artifact_service, self._server)
   hostname = self.get_service_address()
   self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
       url='%s:%d' % (hostname, port))
   self._server.start()
   # Hand the arguments to the logger rather than pre-formatting the message.
   _LOGGER.info('Grpc server started at %s on port %d', hostname, port)
   return port
Exemple #15
0
    def test_exception_propagation(self):
        """A task's exception is re-raised by result(), traceback included."""
        with UnboundedThreadPoolExecutor() as executor:
            future = executor.submit(self.raise_error, 'footest')

        message = None
        try:
            future.result()
        except Exception:  # pylint: disable=broad-except
            message = traceback.format_exc()
        if message is None:
            raise AssertionError('expected exception not raised')

        for expected_fragment in ('footest', 'raise_error'):
            self.assertIn(expected_fragment, message)
Exemple #16
0
    def setUp(self):
        """Starts a local logging servicer and routes root-logger records to it."""
        self.test_logging_service = BeamFnLoggingServicer()
        self.server = grpc.server(UnboundedThreadPoolExecutor())
        beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
            self.test_logging_service, self.server)
        self.test_port = self.server.add_insecure_port('[::]:0')
        self.server.start()

        descriptor = endpoints_pb2.ApiServiceDescriptor()
        descriptor.url = 'localhost:%s' % self.test_port
        self.logging_service_descriptor = descriptor
        self.fn_log_handler = log_handler.FnApiLogRecordHandler(descriptor)

        root_logger = logging.getLogger()
        root_logger.setLevel(logging.INFO)
        root_logger.addHandler(self.fn_log_handler)
Exemple #17
0
    def test_worker_reuse(self):
        """Ten short tasks in two batches run on fewer than ten distinct threads."""
        # Named `pending` so the local does not shadow the `futures` module.
        pending = []
        with UnboundedThreadPoolExecutor() as executor:
            pending.extend(
                executor.submit(self.append_and_sleep, 0.01)
                for _ in range(5))
            # Presumably gives the first batch time to finish so that its
            # threads are idle and can be reused by the second batch.
            time.sleep(3)
            pending.extend(
                executor.submit(self.append_and_sleep, 0.01)
                for _ in range(5))

        for fut in pending:
            fut.result(timeout=10)

        with self._lock:
            self.assertEqual(10, len(self._worker_idents))
            self.assertTrue(len(set(self._worker_idents)) < 10)
Exemple #18
0
  def process_bundle(self,
                     inputs,  # type: Mapping[str, PartitionableBuffer]
                     expected_outputs,  # type: DataOutput
                     fired_timers,  # type: Mapping[Tuple[str, str], PartitionableBuffer]
                     expected_output_timers  # type: Dict[Tuple[str, str], str]
                     ):
    # type: (...) -> BundleProcessResult
    """Partitions a bundle over the workers and merges their results.

    Each input buffer is partitioned into ``self._num_workers`` parts, and
    every part runs through its own BundleManager on an
    UnboundedThreadPoolExecutor. Every part receives the full
    ``fired_timers``. Monitoring infos are consolidated across parts, and the
    first non-empty error encountered while merging wins.

    Returns:
      A ``(merged_result, split_result_list)`` tuple.
    """
    worker_count = self._num_workers
    part_inputs = [{} for _ in range(worker_count)
                   ]  # type: List[Dict[str, List[bytes]]]
    for input_name, buffer in inputs.items():
      for worker_ix, part in enumerate(buffer.partition(worker_count)):
        part_inputs[worker_ix][input_name] = part

    merged_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
    split_result_list = [
    ]  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]

    def run_part(part_map):
      # type: (...) -> BundleProcessResult
      manager = BundleManager(
          self._worker_handler_list,
          self._get_buffer,
          self._get_input_coder_impl,
          self._bundle_descriptor,
          self._progress_frequency,
          self._registered,
          cache_token_generator=self._cache_token_generator)
      return manager.process_bundle(
          part_map, expected_outputs, fired_timers, expected_output_timers)

    with UnboundedThreadPoolExecutor() as executor:
      for result, split_result in executor.map(run_part, part_inputs):
        split_result_list += split_result
        if merged_result is None:
          merged_result = result
          continue
        consolidated = monitoring_infos.consolidate(
            itertools.chain(
                result.process_bundle.monitoring_infos,
                merged_result.process_bundle.monitoring_infos))
        merged_result = beam_fn_api_pb2.InstructionResponse(
            process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                monitoring_infos=consolidated),
            error=result.error or merged_result.error)
    assert merged_result is not None

    return merged_result, split_result_list
 def test_with_grpc(self):
     """Runs the staging/retrieval round trip against a real local gRPC server."""
     server = grpc.server(UnboundedThreadPoolExecutor())
     try:
         beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
             self._service, server)
         beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
             self._service, server)
         port = server.add_insecure_port('[::]:0')
         server.start()
         channel = grpc.insecure_channel('localhost:%d' % port)
         try:
             self._run_staging(
                 beam_artifact_api_pb2_grpc.ArtifactStagingServiceStub(channel),
                 beam_artifact_api_pb2_grpc.ArtifactRetrievalServiceStub(
                     channel))
         finally:
             # Close the channel even when staging fails; previously a
             # failure in _run_staging skipped channel.close().
             channel.close()
     finally:
         server.stop(1)
Exemple #20
0
def main(unused_argv):
  """Serves the expansion service on localhost until SIGTERM or SIGINT.

  Args:
    unused_argv: ignored; options are parsed by argparse from ``sys.argv``.
  """
  parser = argparse.ArgumentParser()
  # Required: without a port the address became 'localhost:None' and the
  # '%d' log format below failed on None.
  parser.add_argument('-p', '--port',
                      type=int,
                      required=True,
                      help='port on which to serve the job api')
  options = parser.parse_args()
  global server
  server = grpc.server(UnboundedThreadPoolExecutor())
  beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
      expansion_service.ExpansionServiceServicer(PipelineOptions()), server
  )
  # Log the port add_insecure_port actually bound (matters for --port=0).
  port = server.add_insecure_port('localhost:{}'.format(options.port))
  server.start()
  _LOGGER.info('Listening for expansion requests at %d', port)

  signal.signal(signal.SIGTERM, cleanup)
  signal.signal(signal.SIGINT, cleanup)
  # blocking main thread forever.
  signal.pause()
Exemple #21
0
  def start(cls, use_process=False, port=0,
            state_cache_size=0, container_executable=None):
    """Creates a worker pool servicer and serves it on a new gRPC server.

    An atexit hook is registered that kills every process in the pool's
    ``_worker_processes`` mapping.

    Returns:
      A ``(address, server)`` tuple: the ``localhost:<port>`` address of the
      pool and the started gRPC server.
    """
    server = grpc.server(UnboundedThreadPoolExecutor())
    bound_port = server.add_insecure_port('[::]:%s' % port)
    address = 'localhost:%s' % bound_port
    pool = cls(use_process=use_process,
               container_executable=container_executable,
               state_cache_size=state_cache_size)
    beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
        pool, server)
    server.start()

    def _kill_all_workers():
      # Runs at interpreter exit so worker subprocesses do not outlive us.
      for proc in pool._worker_processes.values():
        proc.kill()

    atexit.register(_kill_all_workers)
    return address, server
Exemple #22
0
def main(unused_argv):
    """Serves the expansion service (Docker SDK harness) until SIGTERM/SIGINT.

    Args:
      unused_argv: ignored; options are parsed by argparse from ``sys.argv``.
    """
    parser = argparse.ArgumentParser()
    # Required: without a port the address became 'localhost:None' and the
    # '%d' log format below failed on None.
    parser.add_argument('-p',
                        '--port',
                        type=int,
                        required=True,
                        help='port on which to serve the job api')
    options = parser.parse_args()
    global server
    server = grpc.server(UnboundedThreadPoolExecutor())

    # DOCKER SDK Harness
    beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
        expansion_service.ExpansionServiceServicer(
            PipelineOptions([
                "--experiments", "beam_fn_api", "--sdk_location", "container"
            ])), server)

    # PROCESS SDK Harness
    # beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
    #     expansion_service.ExpansionServiceServicer(
    #         PipelineOptions.from_dictionary({
    #             'environment_type': 'PROCESS',
    #             'environment_config': '{"command": "sdks/python/container/build/target/launcher/darwin_amd64/boot"}',
    #             'experiments': 'beam_fn_api',
    #             'sdk_location': 'container',
    #         })
    #     ), server
    # )

    # Log the port add_insecure_port actually bound (matters for --port=0).
    port = server.add_insecure_port('localhost:{}'.format(options.port))
    server.start()
    _LOGGER.info('Listening for expansion requests at %d', port)

    signal.signal(signal.SIGTERM, cleanup)
    signal.signal(signal.SIGINT, cleanup)
    # blocking main thread forever.
    signal.pause()
Exemple #23
0
 def test_shutdown_with_no_workers(self):
     """Entering and leaving the executor with no submitted work succeeds."""
     executor = UnboundedThreadPoolExecutor()
     with executor:
         pass
Exemple #24
0
    def test_map(self):
        """map() runs all five tasks, and any task exception fails the test."""
        with UnboundedThreadPoolExecutor() as executor:
            # Consume the result iterator. Otherwise an exception raised
            # inside a task is silently discarded instead of failing the test.
            list(executor.map(self.append_and_sleep, itertools.repeat(0.01, 5)))

        with self._lock:
            self.assertEqual(5, len(self._worker_idents))
Exemple #25
0
class SdkHarness(object):
    """Receives instructions from the runner's control service and executes them.

    run() streams InstructionRequests from the control service, dispatches
    each one to a ``_request_<type>`` method, and streams the queued
    InstructionResponses back. Instructions are executed on an
    UnboundedThreadPoolExecutor.
    """
    # A request whose oneof 'request' field is X is handled by the method
    # named REQUEST_METHOD_PREFIX + X, e.g. ``_request_register``.
    REQUEST_METHOD_PREFIX = '_request_'

    def __init__(
            self,
            control_address,
            credentials=None,
            worker_id=None,
            # Caching is disabled by default
            state_cache_size=0,
            profiler_factory=None):
        """Opens the control channel and builds the harness's components.

        Blocks for up to 60 seconds until the control channel is ready. A
        secure channel is used when ``credentials`` is given; otherwise the
        channel is insecure.
        """
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        if credentials is None:
            _LOGGER.info('Creating insecure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address)
        else:
            _LOGGER.info('Creating secure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials)
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        _LOGGER.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory
        self._fns = {}
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)
        self._worker_thread_pool = UnboundedThreadPoolExecutor()
        self._responses = queue.Queue()
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')

    def run(self):
        """Consumes the control stream until it ends, then releases resources.

        After the stream ends, the worker pool is shut down so existing
        requests finish, and the response stream is terminated. Finally the
        data channel factory, state handler factory and bundle processor
        cache are closed or shut down.
        """
        control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
            self._control_channel)
        # Sentinel that ends the response generator below.
        no_more_work = object()

        def get_responses():
            while True:
                response = self._responses.get()
                if response is no_more_work:
                    return
                yield response

        self._alive = True

        try:
            for work_request in control_stub.Control(get_responses()):
                _LOGGER.debug('Got work %s', work_request.instruction_id)
                request_type = work_request.WhichOneof('request')
                # Name spacing the request method with '_request_'. The called
                # method will be like self._request_register(request)
                getattr(self, SdkHarness.REQUEST_METHOD_PREFIX +
                        request_type)(work_request)
        finally:
            self._alive = False

        _LOGGER.info('No more requests from control plane')
        _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
        # Wait until existing requests are processed.
        self._worker_thread_pool.shutdown()
        # get_responses may be blocked on responses.get(), but we need to return
        # control to its caller.
        self._responses.put(no_more_work)
        # Stop all the workers and clean all the associated resources
        self._data_channel_factory.close()
        self._state_handler_factory.close()
        self._bundle_processor_cache.shutdown()
        _LOGGER.info('Done consuming work.')

    def _execute(self, task, request):
        """Runs ``task`` and enqueues its response.

        If ``task`` raises, an InstructionResponse carrying the traceback as
        its error is enqueued instead, and the traceback is also printed to
        stderr and logged.
        """
        try:
            response = task()
        except Exception:  # pylint: disable=broad-except
            traceback_string = traceback.format_exc()
            print(traceback_string, file=sys.stderr)
            _LOGGER.error(
                'Error processing instruction %s. Original traceback is\n%s\n',
                request.instruction_id, traceback_string)
            response = beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id, error=traceback_string)
        self._responses.put(response)

    def _request_register(self, request):
        """Handles a register request on the worker pool."""
        self._request_execute(request)

    def _request_process_bundle(self, request):
        """Submits a process_bundle request to the worker pool."""
        def task():
            self._execute(lambda: self.create_worker().do_instruction(request),
                          request)

        self._worker_thread_pool.submit(task)
        # Pass arguments to the logger rather than formatting eagerly, so the
        # message is only built when DEBUG logging is enabled.
        _LOGGER.debug("Currently using %s threads.",
                      len(self._worker_thread_pool._workers))

    def _request_process_bundle_split(self, request):
        """Handles a split request for an active bundle."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_progress(self, request):
        """Handles a progress request for an active bundle."""
        self._request_process_bundle_action(request)

    def _request_process_bundle_action(self, request):
        """Runs a split/progress request if its bundle is active.

        If the referenced instruction is not among the bundle processor
        cache's active bundle processors, an error response is enqueued
        instead.
        """
        def task():
            instruction_id = getattr(
                request, request.WhichOneof('request')).instruction_id
            # only process progress/split request when a bundle is in processing.
            if (instruction_id
                    in self._bundle_processor_cache.active_bundle_processors):
                self._execute(
                    lambda: self.create_worker().do_instruction(request),
                    request)
            else:
                self._execute(
                    lambda: beam_fn_api_pb2.InstructionResponse(
                        instruction_id=request.instruction_id,
                        error=('Unknown process bundle instruction {}').format(
                            instruction_id)), request)

        self._worker_thread_pool.submit(task)

    def _request_finalize_bundle(self, request):
        """Handles a finalize_bundle request on the worker pool."""
        self._request_execute(request)

    def _request_execute(self, request):
        """Submits ``request`` to the worker pool, handled by a new SdkWorker."""
        def task():
            self._execute(lambda: self.create_worker().do_instruction(request),
                          request)

        self._worker_thread_pool.submit(task)

    def create_worker(self):
        """Builds an SdkWorker over the shared bundle processor cache."""
        return SdkWorker(
            self._bundle_processor_cache,
            state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
            profiler_factory=self._profiler_factory)
Exemple #26
0
class SdkHarness(object):
  """Connects to a runner's control plane and dispatches its instructions.

  Each InstructionRequest received on the control stream is routed to the
  method named ``_request_<oneof field>`` (see REQUEST_METHOD_PREFIX); the
  resulting InstructionResponses are queued and streamed back to the runner.
  """
  REQUEST_METHOD_PREFIX = '_request_'

  def __init__(self,
               control_address,  # type: str
               credentials=None,
               worker_id=None,  # type: Optional[str]
               # Caching is disabled by default
               state_cache_size=0,
               # time-based data buffering is disabled by default
               data_buffer_time_limit_ms=0,
               profiler_factory=None,  # type: Optional[Callable[..., Profile]]
               status_address=None,  # type: Optional[str]
               ):
    """Open the control channel and set up resources shared by all workers.

    Waits up to 60 seconds for the control channel to become ready.

    Args:
      control_address: address of the runner's control service.
      credentials: gRPC channel credentials; None selects an insecure control
        channel. Also passed to the data channel and state handler factories.
      worker_id: id given to the WorkerIdInterceptor wrapping the control
        channel and to the data channel factory.
      state_cache_size: size passed to StateCache (0 by default).
      data_buffer_time_limit_ms: passed to GrpcClientDataChannelFactory.
      profiler_factory: optional factory handed to every SdkWorker.
      status_address: if set, address for a FnApiWorkerStatusHandler. A
        failure to create that handler is logged and status reporting is
        skipped.
    """
    self._alive = True
    self._worker_index = 0
    self._worker_id = worker_id
    self._state_cache = StateCache(state_cache_size)
    if credentials is None:
      _LOGGER.info('Creating insecure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.insecure_channel(
          control_address)
    else:
      _LOGGER.info('Creating secure control channel for %s.', control_address)
      self._control_channel = GRPCChannelFactory.secure_channel(
          control_address, credentials)
    grpc.channel_ready_future(self._control_channel).result(timeout=60)
    _LOGGER.info('Control channel established.')

    self._control_channel = grpc.intercept_channel(
        self._control_channel, WorkerIdInterceptor(self._worker_id))
    self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
        credentials, self._worker_id, data_buffer_time_limit_ms)
    self._state_handler_factory = GrpcStateHandlerFactory(
        self._state_cache, credentials)
    self._profiler_factory = profiler_factory
    # Process bundle descriptors are requested from the runner on demand.
    # KeyedDefaultDict presumably invokes this factory with the missing key.
    # Note self._control_stub is only assigned in run(), so lookups must not
    # happen before run() has started.
    self._fns = KeyedDefaultDict(
        lambda descriptor_id: self._control_stub.GetProcessBundleDescriptor(
            beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
                process_bundle_descriptor_id=descriptor_id))
    )  # type: Mapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
    # BundleProcessor cache across all workers.
    self._bundle_processor_cache = BundleProcessorCache(
        state_handler_factory=self._state_handler_factory,
        data_channel_factory=self._data_channel_factory,
        fns=self._fns)

    # Assigned up front so that a failure to create the status handler below
    # leaves None behind instead of a missing attribute (run() reads it).
    self._status_handler = None  # type: Optional[FnApiWorkerStatusHandler]
    if status_address:
      try:
        self._status_handler = FnApiWorkerStatusHandler(
            status_address, self._bundle_processor_cache)
      except Exception:  # pylint: disable=broad-except
        # Status reporting is best-effort; keep the harness running.
        _LOGGER.warning(
            'Error creating worker status request handler, '
            'skipping status report. Trace back: %s', traceback.format_exc())

    # TODO(BEAM-8998) use common UnboundedThreadPoolExecutor to process bundle
    #  progress once dataflow runner's excessive progress polling is removed.
    self._report_progress_executor = futures.ThreadPoolExecutor(max_workers=1)
    self._worker_thread_pool = UnboundedThreadPoolExecutor()
    self._responses = queue.Queue(
    )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
    _LOGGER.info('Initializing SDKHarness with unbounded number of workers.')

  def run(self):
    """Consume instructions from the control stream until it ends.

    Each request is dispatched to ``self._request_<type>(request)``. Once the
    runner closes the stream, waits for in-flight work, terminates the
    response stream and closes the data, state, bundle processor and status
    resources. If the control stream raises, the exception propagates and
    only ``_alive`` is reset.
    """
    self._control_stub = beam_fn_api_pb2_grpc.BeamFnControlStub(
        self._control_channel)
    no_more_work = object()

    def get_responses():
      # type: () -> Iterator[beam_fn_api_pb2.InstructionResponse]
      while True:
        response = self._responses.get()
        if response is no_more_work:
          return
        yield response

    self._alive = True

    try:
      for work_request in self._control_stub.Control(get_responses()):
        _LOGGER.debug('Got work %s', work_request.instruction_id)
        request_type = work_request.WhichOneof('request')
        # Dispatch on the request's oneof field name, prefixed with
        # '_request_'; e.g. a 'register' request is handled by
        # self._request_register(request).
        getattr(self, SdkHarness.REQUEST_METHOD_PREFIX + request_type)(
            work_request)
    finally:
      self._alive = False

    _LOGGER.info('No more requests from control plane')
    _LOGGER.info('SDK Harness waiting for in-flight requests to complete')
    # Wait until existing requests are processed.
    self._worker_thread_pool.shutdown()
    # Progress/split requests run on their own executor; shut it down too so
    # its thread does not outlive the harness.
    self._report_progress_executor.shutdown()
    # get_responses may be blocked on responses.get(), but we need to return
    # control to its caller.
    self._responses.put(no_more_work)
    # Stop all the workers and clean all the associated resources
    self._data_channel_factory.close()
    self._state_handler_factory.close()
    self._bundle_processor_cache.shutdown()
    if self._status_handler:
      self._status_handler.close()
    _LOGGER.info('Done consuming work.')

  def _execute(self,
               task,  # type: Callable[[], beam_fn_api_pb2.InstructionResponse]
               request  # type:  beam_fn_api_pb2.InstructionRequest
              ):
    # type: (...) -> None
    """Run ``task`` under the request's instruction id and queue its response.

    Any exception raised by ``task`` is printed to stderr, logged, and turned
    into an InstructionResponse whose ``error`` is the formatted traceback.
    """
    with statesampler.instruction_id(request.instruction_id):
      try:
        response = task()
      except Exception:  # pylint: disable=broad-except
        traceback_string = traceback.format_exc()
        print(traceback_string, file=sys.stderr)
        _LOGGER.error(
            'Error processing instruction %s. Original traceback is\n%s\n',
            request.instruction_id,
            traceback_string)
        response = beam_fn_api_pb2.InstructionResponse(
            instruction_id=request.instruction_id, error=traceback_string)
      self._responses.put(response)

  def _request_register(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a Register request on the calling (control) thread."""
    # registration request is handled synchronously
    self._execute(lambda: self.create_worker().do_instruction(request), request)

  def _request_process_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ProcessBundle request asynchronously."""
    self._request_execute(request)

  def _request_process_bundle_split(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ProcessBundleSplit request."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_progress(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a ProcessBundleProgress request."""
    self._request_process_bundle_action(request)

  def _request_process_bundle_action(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Answer a split/progress request on the single-thread progress executor.

    If the bundle referenced by the request's payload is not active in the
    bundle processor cache, an error response is produced instead of
    consulting a worker.
    """

    def task():
      instruction_id = getattr(
          request, request.WhichOneof('request')).instruction_id
      # only process progress/split request when a bundle is in processing.
      if (instruction_id in
          self._bundle_processor_cache.active_bundle_processors):
        self._execute(
            lambda: self.create_worker().do_instruction(request), request)
      else:
        self._execute(
            lambda: beam_fn_api_pb2.InstructionResponse(
                instruction_id=request.instruction_id,
                error=('Unknown process bundle instruction {}').format(
                    instruction_id)),
            request)

    self._report_progress_executor.submit(task)

  def _request_finalize_bundle(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Handle a FinalizeBundle request asynchronously."""
    self._request_execute(request)

  def _request_execute(self, request):
    # type: (beam_fn_api_pb2.InstructionRequest) -> None
    """Run ``request`` on a new SdkWorker in the unbounded worker pool."""
    def task():
      self._execute(
          lambda: self.create_worker().do_instruction(request), request)

    self._worker_thread_pool.submit(task)
    # Lazy %-args: the message is only formatted when DEBUG is enabled.
    # ``_workers`` is a private pool attribute, read only for diagnostics.
    _LOGGER.debug(
        'Currently using %s threads.', len(self._worker_thread_pool._workers))

  def create_worker(self):
    """Return a new SdkWorker sharing this harness's bundle processor cache."""
    return SdkWorker(
        self._bundle_processor_cache,
        state_cache_metrics_fn=self._state_cache.get_monitoring_infos,
        profiler_factory=self._profiler_factory)
Exemple #27
0
    def __init__(
            self,
            state,  # type: StateServicer
            provision_info,  # type: Optional[ExtendedProvisionInfo]
            worker_manager,  # type: WorkerHandlerManager
    ):
        # type: (...) -> None
        """Create, wire up and start the gRPC servers used by SDK workers.

        Four servers are created (control, data, state, logging), each with
        its own UnboundedThreadPoolExecutor and each bound without TLS via
        add_insecure_port('[::]:0'); the returned port numbers are stored in
        the corresponding ``*_port`` attributes. The data, state and logging
        servers are configured with unlimited message sizes; the control
        server is not.

        Args:
          state: wrapped in a GrpcStateServicer and served on the state server.
          provision_info: when truthy, provision and legacy artifact-retrieval
            services are also registered on the control server.
          worker_manager: passed to BeamFnControlServicer and, when provision
            info is served, to BasicProvisionService.
        """
        self.state = state
        self.provision_info = provision_info
        self.control_server = grpc.server(UnboundedThreadPoolExecutor())
        self.control_port = self.control_server.add_insecure_port('[::]:0')
        # Presumably the address handed to workers to reach the control
        # service; note it always uses 'localhost' — confirm for remote workers.
        self.control_address = 'localhost:%s' % self.control_port

        # Options to have no limits (-1) on the size of the messages
        # received or sent over the data plane. The actual buffer size
        # is controlled in a layer above.
        no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                                ("grpc.max_send_message_length", -1)]
        self.data_server = grpc.server(UnboundedThreadPoolExecutor(),
                                       options=no_max_message_sizes)
        self.data_port = self.data_server.add_insecure_port('[::]:0')

        self.state_server = grpc.server(UnboundedThreadPoolExecutor(),
                                        options=no_max_message_sizes)
        self.state_port = self.state_server.add_insecure_port('[::]:0')

        self.control_handler = BeamFnControlServicer(worker_manager)
        beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
            self.control_handler, self.control_server)

        # If we have provision info, serve these off the control port as well.
        if self.provision_info:
            if self.provision_info.provision_info:
                beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server(
                    BasicProvisionService(self.provision_info.provision_info,
                                          worker_manager), self.control_server)

            # Artifacts come from the staging dir when one is configured;
            # otherwise an empty retrieval service is registered.
            if self.provision_info.artifact_staging_dir:
                service = artifact_service.BeamFilesystemArtifactService(
                    self.provision_info.artifact_staging_dir
                )  # type: beam_artifact_api_pb2_grpc.LegacyArtifactRetrievalServiceServicer
            else:
                service = EmptyArtifactRetrievalService()
            beam_artifact_api_pb2_grpc.add_LegacyArtifactRetrievalServiceServicer_to_server(
                service, self.control_server)

        self.data_plane_handler = data_plane.BeamFnDataServicer(
            DATA_BUFFER_TIME_LIMIT_MS)
        beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
            self.data_plane_handler, self.data_server)

        beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
            GrpcStateServicer(state), self.state_server)

        self.logging_server = grpc.server(UnboundedThreadPoolExecutor(),
                                          options=no_max_message_sizes)
        self.logging_port = self.logging_server.add_insecure_port('[::]:0')
        beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
            BasicLoggingService(), self.logging_server)

        _LOGGER.info('starting control server on port %s', self.control_port)
        _LOGGER.info('starting data server on port %s', self.data_port)
        _LOGGER.info('starting state server on port %s', self.state_port)
        _LOGGER.info('starting logging server on port %s', self.logging_port)
        # NOTE(review): the control server is started last, presumably so the
        # other services are already listening when a worker connects — confirm.
        self.logging_server.start()
        self.state_server.start()
        self.data_server.start()
        self.control_server.start()