Example #1
    def test_shared_shutdown_does_nothing(self):
        thread_pool_executor.shared_unbounded_instance().shutdown()

        futures = []
        with thread_pool_executor.shared_unbounded_instance() as executor:
            for _ in range(0, 5):
                futures.append(executor.submit(self.append_and_sleep, 0.01))

        for future in futures:
            future.result(timeout=10)

        with self._lock:
            self.assertEqual(5, len(self._worker_idents))
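Note: the test above relies on shutdown() being a no-op for the shared executor. A minimal standalone sketch of that behaviour, assuming the module lives at apache_beam.utils.thread_pool_executor as in the Beam Python SDK:

from apache_beam.utils import thread_pool_executor

def double(x):
    return 2 * x

executor = thread_pool_executor.shared_unbounded_instance()
futures = [executor.submit(double, i) for i in range(5)]
print([f.result(timeout=10) for f in futures])  # [0, 2, 4, 6, 8]

# shutdown() on the shared instance does not terminate it, so later callers
# (like the test above) can keep submitting work afterwards.
executor.shutdown()
print(executor.submit(double, 21).result(timeout=10))  # 42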
Example #2
    @classmethod
    def start(
            cls,
            use_process=False,
            port=0,
            state_cache_size=0,
            data_buffer_time_limit_ms=-1,
            container_executable=None  # type: Optional[str]
    ):
        # type: (...) -> Tuple[str, grpc.Server]
        options = [("grpc.http2.max_pings_without_data", 0),
                   ("grpc.http2.max_ping_strikes", 0)]
        worker_server = grpc.server(
            thread_pool_executor.shared_unbounded_instance(), options=options)
        worker_address = 'localhost:%s' % worker_server.add_insecure_port(
            '[::]:%s' % port)
        worker_pool = cls(use_process=use_process,
                          container_executable=container_executable,
                          state_cache_size=state_cache_size,
                          data_buffer_time_limit_ms=data_buffer_time_limit_ms)
        beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
            worker_pool, worker_server)
        worker_server.start()
        _LOGGER.info('Listening for workers at %s', worker_address)

        # Register to kill the subprocesses on exit.
        def kill_worker_processes():
            for worker_process in worker_pool._worker_processes.values():
                worker_process.kill()

        atexit.register(kill_worker_processes)

        return worker_address, worker_server
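Examples #2, #12 and #13 build a grpc.server on top of the shared executor with the same HTTP/2 keepalive options. A stripped-down sketch of just that server setup, with no servicers registered (port 0 lets the OS pick a free port):

import grpc

from apache_beam.utils import thread_pool_executor

# Setting both options to 0 disables the corresponding limits, so long-lived
# idle connections are not dropped for sending keepalive pings without data.
options = [("grpc.http2.max_pings_without_data", 0),
           ("grpc.http2.max_ping_strikes", 0)]

server = grpc.server(
    thread_pool_executor.shared_unbounded_instance(), options=options)
port = server.add_insecure_port('[::]:0')
server.start()
print('Serving on port %d' % port)
server.stop(grace=None)  # abort immediately; pass seconds to allow draining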
Example #3
def main(argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--port',
                        type=int,
                        help='port on which to serve the job api')
    parser.add_argument('--fully_qualified_name_glob', default=None)
    known_args, pipeline_args = parser.parse_known_args(argv)
    pipeline_options = PipelineOptions(
        pipeline_args +
        ["--experiments=beam_fn_api", "--sdk_location=container"])

    with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
            known_args.fully_qualified_name_glob):

        server = grpc.server(thread_pool_executor.shared_unbounded_instance())
        beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
            expansion_service.ExpansionServiceServicer(pipeline_options),
            server)
        beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
            artifact_service.ArtifactRetrievalService(
                artifact_service.BeamFilesystemHandler(None).file_reader),
            server)
        server.add_insecure_port('localhost:{}'.format(known_args.port))
        server.start()
        _LOGGER.info('Listening for expansion requests at %d', known_args.port)

        def cleanup(unused_signum, unused_frame):
            _LOGGER.info('Shutting down expansion service.')
            server.stop(None)

        signal.signal(signal.SIGTERM, cleanup)
        signal.signal(signal.SIGINT, cleanup)
        # Block the main thread forever.
        signal.pause()
Example #4
def main(unused_argv):
    PyPIArtifactRegistry.register_artifact('beautifulsoup4', '>=4.9,<5.0')
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--port',
                        type=int,
                        help='port on which to serve the job api')
    parser.add_argument('--fully_qualified_name_glob', default=None)
    options = parser.parse_args()

    global server
    with fully_qualified_named_transform.FullyQualifiedNamedTransform.with_filter(
            options.fully_qualified_name_glob):
        server = grpc.server(thread_pool_executor.shared_unbounded_instance())
        beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
            expansion_service.ExpansionServiceServicer(
                PipelineOptions([
                    "--experiments", "beam_fn_api", "--sdk_location",
                    "container"
                ])), server)
        beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
            artifact_service.ArtifactRetrievalService(
                artifact_service.BeamFilesystemHandler(None).file_reader),
            server)
        server.add_insecure_port('localhost:{}'.format(options.port))
        server.start()
        _LOGGER.info('Listening for expansion requests at %d', options.port)

        signal.signal(signal.SIGTERM, cleanup)
        signal.signal(signal.SIGINT, cleanup)
        # Block the main thread forever.
        signal.pause()
Example #5
    def _stage_files(self, files):
        """Utility method to stage files.

      Args:
        files: a list of tuples of the form [(local_name, remote_name),...]
          describing the name of the artifacts in local temp folder and desired
          name in staging location.
    """
        server = grpc.server(thread_pool_executor.shared_unbounded_instance())
        staging_service = TestLocalFileSystemLegacyArtifactStagingServiceServicer(
            self._remote_dir)
        beam_artifact_api_pb2_grpc.add_LegacyArtifactStagingServiceServicer_to_server(
            staging_service, server)
        test_port = server.add_insecure_port('[::]:0')
        server.start()
        stager = portable_stager.PortableStager(
            artifact_service_channel=grpc.insecure_channel('localhost:%s' %
                                                           test_port),
            staging_session_token='token')
        for from_file, to_file in files:
            stager.stage_artifact(local_path_to_artifact=os.path.join(
                self._temp_dir, from_file),
                                  artifact_name=to_file)
        stager.commit_manifest()
        return staging_service.manifest.artifact, staging_service.retrieval_tokens
Example #6
    def process_bundle(
        self,
        inputs,  # type: Mapping[str, execution.PartitionableBuffer]
        expected_outputs,  # type: DataOutput
        fired_timers,  # type: Mapping[Tuple[str, str], execution.PartitionableBuffer]
        expected_output_timers,  # type: Dict[Tuple[str, str], str]
        dry_run=False,
    ):
        # type: (...) -> BundleProcessResult
        part_inputs = [{} for _ in range(self._num_workers)
                       ]  # type: List[Dict[str, List[bytes]]]
        # Timers are only executed on the first worker
        # TODO(BEAM-9741): Split timers to multiple workers
        timer_inputs = [
            fired_timers if i == 0 else {} for i in range(self._num_workers)
        ]
        for name, input in inputs.items():
            for ix, part in enumerate(input.partition(self._num_workers)):
                part_inputs[ix][name] = part

        merged_result = None  # type: Optional[beam_fn_api_pb2.InstructionResponse]
        split_result_list = [
        ]  # type: List[beam_fn_api_pb2.ProcessBundleSplitResponse]

        def execute(part_map_input_timers):
            # type: (...) -> BundleProcessResult
            part_map, input_timers = part_map_input_timers
            bundle_manager = BundleManager(
                self.bundle_context_manager,
                self._progress_frequency,
                cache_token_generator=self._cache_token_generator)
            return bundle_manager.process_bundle(part_map, expected_outputs,
                                                 input_timers,
                                                 expected_output_timers,
                                                 dry_run)

        with thread_pool_executor.shared_unbounded_instance() as executor:
            for result, split_result in executor.map(
                    execute,
                    zip(
                        part_inputs,  # pylint: disable=zip-builtin-not-iterating
                        timer_inputs)):
                split_result_list += split_result
                if merged_result is None:
                    merged_result = result
                else:
                    merged_result = beam_fn_api_pb2.InstructionResponse(
                        process_bundle=beam_fn_api_pb2.ProcessBundleResponse(
                            monitoring_infos=monitoring_infos.consolidate(
                                itertools.chain(
                                    result.process_bundle.monitoring_infos,
                                    merged_result.process_bundle.
                                    monitoring_infos))),
                        error=result.error or merged_result.error)
        assert merged_result is not None
        return merged_result, split_result_list
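Example #6 above fans the partitioned inputs out with executor.map and folds the per-worker responses back into a single result. A simplified sketch of the same fan-out/merge shape, with the Beam bundle types replaced by plain lists (the execute stand-in below is illustrative only, not Beam's API):

from apache_beam.utils import thread_pool_executor

def execute(part):
    # Stand-in for BundleManager.process_bundle: returns (result, split_results).
    return sum(part), [len(part)]

part_inputs = [[1, 2], [3, 4], [5]]

merged_result = None
split_result_list = []
with thread_pool_executor.shared_unbounded_instance() as executor:
    for result, split_result in executor.map(execute, part_inputs):
        split_result_list += split_result
        merged_result = result if merged_result is None else merged_result + result

print(merged_result, split_result_list)  # 15 [2, 2, 1]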
Example #7
 def setUp(self):
     self.num_request = 3
     self.test_status_service = BeamFnStatusServicer(self.num_request)
     self.server = grpc.server(
         thread_pool_executor.shared_unbounded_instance())
     beam_fn_api_pb2_grpc.add_BeamFnWorkerStatusServicer_to_server(
         self.test_status_service, self.server)
     self.test_port = self.server.add_insecure_port('[::]:0')
     self.server.start()
     self.url = 'localhost:%s' % self.test_port
     self.fn_status_handler = FnApiWorkerStatusHandler(self.url)
Example #8
    def start(self):
        if not self._worker_address:
            worker_server = grpc.server(
                thread_pool_executor.shared_unbounded_instance())
            worker_address = 'localhost:%s' % worker_server.add_insecure_port(
                '[::]:0')
            beam_fn_api_pb2_grpc.add_BeamFnExternalWorkerPoolServicer_to_server(
                self, worker_server)
            worker_server.start()

            self._worker_address = worker_address
            atexit.register(functools.partial(worker_server.stop, 1))
        return self._worker_address
Example #9
    def get_responses(self, instruction_requests):
        """Evaluates and returns {id: InstructionResponse} for the requests."""
        test_controller = BeamFnControlServicer(instruction_requests)

        server = grpc.server(thread_pool_executor.shared_unbounded_instance())
        beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
            test_controller, server)
        test_port = server.add_insecure_port("[::]:0")
        server.start()

        harness = sdk_worker.SdkHarness("localhost:%s" % test_port,
                                        state_cache_size=100)
        harness.run()
        return test_controller.responses
Example #10
  def setUp(self):
    self.test_logging_service = BeamFnLoggingServicer()
    self.server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
        self.test_logging_service, self.server)
    self.test_port = self.server.add_insecure_port('[::]:0')
    self.server.start()

    self.logging_service_descriptor = endpoints_pb2.ApiServiceDescriptor()
    self.logging_service_descriptor.url = 'localhost:%s' % self.test_port
    self.fn_log_handler = log_handler.FnApiLogRecordHandler(
        self.logging_service_descriptor)
    logging.getLogger().setLevel(logging.INFO)
    logging.getLogger().addHandler(self.fn_log_handler)
Example #11
 def start_grpc_server(self, port=0):
     self._server = grpc.server(
         thread_pool_executor.shared_unbounded_instance())
     port = self._server.add_insecure_port('%s:%d' %
                                           (self.get_bind_address(), port))
     beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(
         self, self._server)
     beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
         self._artifact_service, self._server)
     hostname = self.get_service_address()
     self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
         url='%s:%d' % (hostname, port))
     self._server.start()
     _LOGGER.info('Grpc server started at %s on port %d' % (hostname, port))
     return port
Example #12
 def start_grpc_server(self, port=0):
   options = [("grpc.max_receive_message_length", -1),
              ("grpc.max_send_message_length", -1),
              ("grpc.http2.max_pings_without_data", 0),
              ("grpc.http2.max_ping_strikes", 0)]
   self._server = grpc.server(
       thread_pool_executor.shared_unbounded_instance(), options=options)
   port = self._server.add_insecure_port(
       '%s:%d' % (self.get_bind_address(), port))
   beam_job_api_pb2_grpc.add_JobServiceServicer_to_server(self, self._server)
   beam_artifact_api_pb2_grpc.add_ArtifactStagingServiceServicer_to_server(
       self._artifact_service, self._server)
   hostname = self.get_service_address()
   self._artifact_staging_endpoint = endpoints_pb2.ApiServiceDescriptor(
       url='%s:%d' % (hostname, port))
   self._server.start()
   _LOGGER.info('Grpc server started at %s on port %d' % (hostname, port))
   return port
Example #13
    def run(self):
        options = [("grpc.http2.max_pings_without_data", 0),
                   ("grpc.http2.max_ping_strikes", 0)]
        logging_server = grpc.server(
            thread_pool_executor.shared_unbounded_instance(), options=options)
        logging_port = logging_server.add_insecure_port('[::]:0')
        logging_server.start()
        logging_servicer = BeamFnLoggingServicer()
        beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
            logging_servicer, logging_server)
        logging_descriptor = text_format.MessageToString(
            endpoints_pb2.ApiServiceDescriptor(url='localhost:%s' %
                                               logging_port))

        control_descriptor = text_format.MessageToString(
            endpoints_pb2.ApiServiceDescriptor(url=self._control_address))
        pipeline_options = json_format.MessageToJson(
            self._provision_info.provision_info.pipeline_options)

        env_dict = dict(os.environ,
                        CONTROL_API_SERVICE_DESCRIPTOR=control_descriptor,
                        LOGGING_API_SERVICE_DESCRIPTOR=logging_descriptor,
                        PIPELINE_OPTIONS=pipeline_options)
        # only add worker_id when it is set.
        if self._worker_id:
            env_dict['WORKER_ID'] = self._worker_id

        with worker_handlers.SUBPROCESS_LOCK:
            p = subprocess.Popen(self._worker_command_line,
                                 shell=True,
                                 env=env_dict)
        try:
            p.wait()
            if p.returncode:
                raise RuntimeError(
                    'Worker subprocess exited with return code %s' %
                    p.returncode)
        finally:
            if p.poll() is None:
                p.kill()
            logging_server.stop(0)
Example #14
def main(unused_argv):
    parser = argparse.ArgumentParser()
    parser.add_argument('-p',
                        '--port',
                        type=int,
                        help='port on which to serve the job api')
    options = parser.parse_args()
    global server
    server = grpc.server(thread_pool_executor.shared_unbounded_instance())
    beam_expansion_api_pb2_grpc.add_ExpansionServiceServicer_to_server(
        expansion_service.ExpansionServiceServicer(
            PipelineOptions([
                "--experiments", "beam_fn_api", "--sdk_location", "container"
            ])), server)
    server.add_insecure_port('localhost:{}'.format(options.port))
    server.start()
    _LOGGER.info('Listening for expansion requests at %d', options.port)

    signal.signal(signal.SIGTERM, cleanup)
    signal.signal(signal.SIGINT, cleanup)
    # Block the main thread forever.
    signal.pause()
Example #15
    def _check_fn_registration_multi_request(self, *args):
        """Check the function registration calls to the sdk_harness.

    Args:
     tuple of request_count, number of process_bundles per request and workers
     counts to process the request.
    """
        for (request_count, process_bundles_per_request) in args:
            requests = []
            process_bundle_descriptors = []

            for i in range(request_count):
                pbd = self._get_process_bundles(i, process_bundles_per_request)
                process_bundle_descriptors.extend(pbd)
                requests.append(
                    beam_fn_api_pb2.InstructionRequest(
                        instruction_id=str(i),
                        register=beam_fn_api_pb2.RegisterRequest(
                            process_bundle_descriptor=process_bundle_descriptors
                        )))

            test_controller = BeamFnControlServicer(requests)

            server = grpc.server(
                thread_pool_executor.shared_unbounded_instance())
            beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
                test_controller, server)
            test_port = server.add_insecure_port("[::]:0")
            server.start()

            harness = sdk_worker.SdkHarness("localhost:%s" % test_port,
                                            state_cache_size=100)
            harness.run()

            self.assertEqual(
                harness._bundle_processor_cache.fns,
                {item.id: item
                 for item in process_bundle_descriptors})
Example #16
    def _grpc_data_channel_test(self, time_based_flush=False):
        if time_based_flush:
            data_servicer = data_plane.BeamFnDataServicer(
                data_buffer_time_limit_ms=100)
        else:
            data_servicer = data_plane.BeamFnDataServicer()
        worker_id = 'worker_0'
        data_channel_service = data_servicer.get_conn_by_worker_id(worker_id)

        server = grpc.server(thread_pool_executor.shared_unbounded_instance())
        beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
            data_servicer, server)
        test_port = server.add_insecure_port('[::]:0')
        server.start()

        grpc_channel = grpc.insecure_channel('localhost:%s' % test_port)
        # Add workerId to the grpc channel
        grpc_channel = grpc.intercept_channel(grpc_channel,
                                              WorkerIdInterceptor(worker_id))
        data_channel_stub = beam_fn_api_pb2_grpc.BeamFnDataStub(grpc_channel)
        if time_based_flush:
            data_channel_client = data_plane.GrpcClientDataChannel(
                data_channel_stub, data_buffer_time_limit_ms=100)
        else:
            data_channel_client = data_plane.GrpcClientDataChannel(
                data_channel_stub)

        try:
            self._data_channel_test(data_channel_service, data_channel_client,
                                    time_based_flush)
        finally:
            data_channel_client.close()
            data_channel_service.close()
            data_channel_client.wait()
            data_channel_service.wait()
Example #17
    def __init__(
            self,
            state,  # type: StateServicer
            provision_info,  # type: Optional[ExtendedProvisionInfo]
            worker_manager,  # type: WorkerHandlerManager
    ):
        # type: (...) -> None
        self.state = state
        self.provision_info = provision_info
        self.control_server = grpc.server(
            thread_pool_executor.shared_unbounded_instance())
        self.control_port = self.control_server.add_insecure_port('[::]:0')
        self.control_address = 'localhost:%s' % self.control_port

        # Options to have no limits (-1) on the size of the messages
        # received or sent over the data plane. The actual buffer size
        # is controlled in a layer above.
        no_max_message_sizes = [("grpc.max_receive_message_length", -1),
                                ("grpc.max_send_message_length", -1)]
        self.data_server = grpc.server(
            thread_pool_executor.shared_unbounded_instance(),
            options=no_max_message_sizes)
        self.data_port = self.data_server.add_insecure_port('[::]:0')

        self.state_server = grpc.server(
            thread_pool_executor.shared_unbounded_instance(),
            options=no_max_message_sizes)
        self.state_port = self.state_server.add_insecure_port('[::]:0')

        self.control_handler = BeamFnControlServicer(worker_manager)
        beam_fn_api_pb2_grpc.add_BeamFnControlServicer_to_server(
            self.control_handler, self.control_server)

        # If we have provision info, serve these off the control port as well.
        if self.provision_info:
            if self.provision_info.provision_info:
                beam_provision_api_pb2_grpc.add_ProvisionServiceServicer_to_server(
                    BasicProvisionService(self.provision_info.provision_info,
                                          worker_manager), self.control_server)

            def open_uncompressed(f):
                # type: (str) -> BinaryIO
                return filesystems.FileSystems.open(
                    f, compression_type=CompressionTypes.UNCOMPRESSED)

            beam_artifact_api_pb2_grpc.add_ArtifactRetrievalServiceServicer_to_server(
                artifact_service.ArtifactRetrievalService(
                    file_reader=open_uncompressed), self.control_server)

        self.data_plane_handler = data_plane.BeamFnDataServicer(
            DATA_BUFFER_TIME_LIMIT_MS)
        beam_fn_api_pb2_grpc.add_BeamFnDataServicer_to_server(
            self.data_plane_handler, self.data_server)

        beam_fn_api_pb2_grpc.add_BeamFnStateServicer_to_server(
            GrpcStateServicer(state), self.state_server)

        self.logging_server = grpc.server(
            thread_pool_executor.shared_unbounded_instance(),
            options=no_max_message_sizes)
        self.logging_port = self.logging_server.add_insecure_port('[::]:0')
        beam_fn_api_pb2_grpc.add_BeamFnLoggingServicer_to_server(
            BasicLoggingService(), self.logging_server)

        _LOGGER.info('starting control server on port %s', self.control_port)
        _LOGGER.info('starting data server on port %s', self.data_port)
        _LOGGER.info('starting state server on port %s', self.state_port)
        _LOGGER.info('starting logging server on port %s', self.logging_port)
        self.logging_server.start()
        self.state_server.start()
        self.data_server.start()
        self.control_server.start()
Example #18
    def __init__(
            self,
            control_address,  # type: str
            credentials=None,
            worker_id=None,  # type: Optional[str]
            # Caching is disabled by default
            state_cache_size=0,
            # time-based data buffering is disabled by default
            data_buffer_time_limit_ms=0,
            profiler_factory=None,  # type: Optional[Callable[..., Profile]]
            status_address=None,  # type: Optional[str]
    ):
        self._alive = True
        self._worker_index = 0
        self._worker_id = worker_id
        self._state_cache = StateCache(state_cache_size)
        if credentials is None:
            _LOGGER.info('Creating insecure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.insecure_channel(
                control_address)
        else:
            _LOGGER.info('Creating secure control channel for %s.',
                         control_address)
            self._control_channel = GRPCChannelFactory.secure_channel(
                control_address, credentials)
        grpc.channel_ready_future(self._control_channel).result(timeout=60)
        _LOGGER.info('Control channel established.')

        self._control_channel = grpc.intercept_channel(
            self._control_channel, WorkerIdInterceptor(self._worker_id))
        self._data_channel_factory = data_plane.GrpcClientDataChannelFactory(
            credentials, self._worker_id, data_buffer_time_limit_ms)
        self._state_handler_factory = GrpcStateHandlerFactory(
            self._state_cache, credentials)
        self._profiler_factory = profiler_factory
        self._fns = KeyedDefaultDict(
            lambda id: self._control_stub.GetProcessBundleDescriptor(
                beam_fn_api_pb2.GetProcessBundleDescriptorRequest(
                    process_bundle_descriptor_id=id)
            ))  # type: Mapping[str, beam_fn_api_pb2.ProcessBundleDescriptor]
        # BundleProcessor cache across all workers.
        self._bundle_processor_cache = BundleProcessorCache(
            state_handler_factory=self._state_handler_factory,
            data_channel_factory=self._data_channel_factory,
            fns=self._fns)

        if status_address:
            try:
                self._status_handler = FnApiWorkerStatusHandler(
                    status_address, self._bundle_processor_cache
                )  # type: Optional[FnApiWorkerStatusHandler]
            except Exception:
                traceback_string = traceback.format_exc()
                _LOGGER.warning(
                    'Error creating worker status request handler, '
                    'skipping status report. Trace back: %s' %
                    traceback_string)
        else:
            self._status_handler = None

        # TODO(BEAM-8998) use common
        # thread_pool_executor.shared_unbounded_instance() to process bundle
        # progress once dataflow runner's excessive progress polling is removed.
        self._report_progress_executor = futures.ThreadPoolExecutor(
            max_workers=1)
        self._worker_thread_pool = thread_pool_executor.shared_unbounded_instance(
        )
        self._responses = queue.Queue(
        )  # type: queue.Queue[beam_fn_api_pb2.InstructionResponse]
        _LOGGER.info(
            'Initializing SDKHarness with unbounded number of workers.')