Example #1
def _notify_and_register_task_addresses(driver, settings, notify=True):
    # wait for num_proc tasks to register
    driver.wait_for_initial_registration(settings.start_timeout)
    if settings.verbose >= 2:
        logging.info('Initial Spark task registration is complete.')

    task_indices = driver.task_indices()
    # pair each task with the next one in ring order: (t0, t1), ..., (tn, t0)
    task_pairs = zip(task_indices, task_indices[1:] + task_indices[0:1])

    def notify_and_register(task_index, next_task_index):
        task_client = task_service.SparkTaskClient(
            task_index, driver.task_addresses_for_driver(task_index),
            settings.key, settings.verbose)

        if notify:
            task_client.notify_initial_registration_complete()

        next_task_addresses = driver.all_task_addresses(next_task_index)
        task_to_task_addresses = task_client.get_task_addresses_for_task(
            next_task_index, next_task_addresses)
        driver.register_task_to_task_addresses(next_task_index,
                                               task_to_task_addresses)

    for task_index, next_task_index in task_pairs:
        in_thread(notify_and_register, (task_index, next_task_index))

    driver.wait_for_task_to_task_address_updates(settings.start_timeout)

    if settings.verbose >= 2:
        logging.info('Spark task-to-task address registration is complete.')
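
All of these examples build on the in_thread helper, whose definition is not shown here. The tests in Example #2 below pin down its contract: it starts target in a new thread with positional args (which must be a tuple), returns the started threading.Thread, and rejects non-tuple args with a specific error message. Below is a minimal sketch consistent with that contract; the name, daemon and silent parameters are inferred from the calls in the examples, not taken from the actual implementation.

import threading


def in_thread(target, args=(), name=None, daemon=True, silent=False):
    """Sketch: runs target(*args) in a new thread and returns the thread.

    Inferred from the surrounding examples: args must be a tuple,
    silent=True swallows exceptions raised by target, and threads are
    daemonic unless daemon=False is passed.
    """
    if not isinstance(args, tuple):
        # this message matches the one asserted in Example #2
        raise ValueError('args must be a tuple, not {}, '
                         'for a single argument use (arg,)'.format(type(args)))

    fn = target
    if silent:
        def fn(*fn_args):
            # swallow any exception so a failing target cannot kill the caller
            try:
                target(*fn_args)
            except:
                pass

    bg = threading.Thread(target=fn, args=args, name=name)
    bg.daemon = daemon
    bg.start()
    return bg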
Example #2
    def test_in_thread_args(self):
        fn = mock.Mock()
        thread = in_thread(fn, args=(1, ))
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once_with(1)

        fn = mock.Mock()
        thread = in_thread(fn, args=(1, 2))
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once_with(1, 2)

        fn = mock.Mock()
        thread = in_thread(fn, args=(1, 2), silent=True)
        thread.join(1.0)
        self.assertFalse(thread.is_alive())
        fn.assert_called_once_with(1, 2)

        fn = mock.Mock()
        with pytest.raises(
                ValueError,
                match="^args must be a tuple, not <(class|type) 'int'>, "
                "for a single argument use \\(arg,\\)$"):
            in_thread(fn, args=1)
        fn.assert_not_called()
Example #3
    def provide_hosts(self, hosts):
        """Makes the cluster provide the given hosts only.

        Any host currently provided that is not in the given hosts will be shut down.
        This does not allow for changes in the number of slots.
        """
        logging.debug('make Spark cluster provide hosts %s', hosts)

        # shut down workers for removed hosts first
        for host in self._host_worker.copy():
            if host not in hosts:
                in_thread(self.stop_worker, args=(self._host_worker[host], ))
                del self._host_worker[host]

        # start new workers
        threads = []
        for host in hosts:
            if host not in self._host_worker:
                cores = int(host.split(':', 1)[1])
                instance = self._next_worker_instance
                threads.append(
                    in_thread(self.start_or_restart_worker,
                              args=(instance, cores)))
                self._host_worker[host] = instance
                self._next_worker_instance += 1
        for thread in threads:
            thread.join(5)
Example #4
    def stream_command_output(self, stdout=None, stderr=None):
        def send(req, stream):
            try:
                self._send(req, stream)
            except Exception as e:
                self.abort_command()
                raise e

        return (in_thread(send, (StreamCommandStdOutRequest(),
                                 stdout)) if stdout else None,
                in_thread(send, (StreamCommandStdErrRequest(),
                                 stderr)) if stderr else None)
Example #5
    def _handle(self, req, client_address):
        if isinstance(req, SleepRequest):
            pipe = Pipe()

            def sleep():
                time.sleep(self._duration)
                pipe.write('slept {}'.format(self._duration))
                pipe.close()

            in_thread(sleep)

            return network.AckStreamResponse(), pipe

        return super(TestStreamService, self)._handle(req, client_address)
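
The Pipe used here (and again in Example #19) is not defined in these snippets. From its use, one side writes chunks and then closes, while the other side reads until the pipe is drained. A hypothetical minimal queue-backed stand-in follows; the real class may buffer and block differently.

import queue


class Pipe:
    """Hypothetical in-memory pipe: write() chunks on one side, read() them
    on the other; read() returns None once close() has been called and the
    buffer is drained."""
    _EOF = object()

    def __init__(self):
        self._chunks = queue.Queue()

    def write(self, data):
        self._chunks.put(data)

    def read(self):
        data = self._chunks.get()
        if data is self._EOF:
            self._chunks.put(self._EOF)  # keep signalling EOF to later reads
            return None
        return data

    def close(self):
        self._chunks.put(self._EOF)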
Example #6
    def wait_for_dispatcher(client, dispatcher_id, queue):
        def _wait():
            queue.put(
                (dispatcher_id,
                 client.wait_for_dispatcher_registration(dispatcher_id, 10)))

        return in_thread(_wait, daemon=True)
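
A hypothetical caller for the helper above, collecting registration results as they arrive; the two-dispatcher setup, the 15-second join and the client object are illustrative, not from the original.

import queue

results = queue.Queue()
threads = [wait_for_dispatcher(client, i, results) for i in range(2)]
for thread in threads:
    thread.join(15)

# drain whatever registrations completed within the join timeout
while not results.empty():
    dispatcher_id, address = results.get()
    print('dispatcher {} registered at {}'.format(dispatcher_id, address))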
Example #7
def _make_spark_thread(spark_context, spark_job_group, driver, result_queue,
                       settings, use_gloo, is_elastic):
    """Creates `settings.num_proc` Spark tasks in a parallel thread."""
    def run_spark():
        """Creates `settings.num_proc` Spark tasks, each executing `_task_fn` and waits for them to terminate."""
        try:
            spark_context.setJobGroup(spark_job_group,
                                      "Horovod Spark Run",
                                      interruptOnCancel=True)
            procs = spark_context.range(
                0,
                numSlices=settings.max_num_proc
                if settings.elastic else settings.num_proc)
            # We assume that folks caring about security will enable Spark RPC encryption,
            # thus ensuring that the key passed here remains secret.
            mapper = _make_mapper(driver.addresses(), settings, use_gloo,
                                  is_elastic)
            result = procs.mapPartitionsWithIndex(mapper).collect()
            result_queue.put(result)
        except:
            driver.notify_spark_job_failed()
            raise

    spark_thread = in_thread(target=run_spark, daemon=False)
    return spark_thread
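
A sketch of the calling side, assuming the caller owns the result_queue (the job group name and the surrounding objects are illustrative): since the thread is created with daemon=False, the caller is expected to join it and only then read the results that run_spark() put on the queue.

import queue

result_queue = queue.Queue()
spark_thread = _make_spark_thread(spark_context, 'horovod-spark-run', driver,
                                  result_queue, settings, use_gloo=True,
                                  is_elastic=False)
# ... orchestrate the run against the registered tasks ...
spark_thread.join()           # non-daemon: wait for the Spark tasks to finish
results = result_queue.get()  # the list collected by run_spark()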
Example #8
def delay(func, seconds):
    """Delays the execution of func in a separate thread by given seconds."""
    def fn():
        time.sleep(seconds)
        func()

    return in_thread(target=fn)
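
For example, to schedule a call without blocking the caller (a hypothetical usage; the service object is assumed to exist):

# run service.shutdown() five seconds from now, in a background thread
thread = delay(service.shutdown, 5.0)
# the caller continues immediately; join the thread if completion matters
thread.join()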
Example #9
    def test_shutdown_during_request_basic(self):
        sleep = 2.0
        key = secret.make_secret_key()
        service = TestSleepService(key, duration=sleep)
        try:
            client = TestSleepClient(service.addresses(), key, attempts=1)
            start = time.time()
            threads = list([
                in_thread(client.sleep,
                          name='request {}'.format(i + 1),
                          daemon=False) for i in range(5)
            ])
            time.sleep(sleep / 2.0)
        finally:
            service.shutdown()

        duration = time.time() - start
        print('shutdown completed in {} seconds'.format(duration))
        self.assertGreaterEqual(duration, sleep,
                                'sleep requests should have been completed')
        self.assertLess(duration, sleep + 1.0,
                        'sleep requests should have been concurrent')

        for thread in threads:
            thread.join(0.1)
            self.assertFalse(thread.is_alive(),
                             'thread should have terminated by now')
Example #10
    def start(self, handler_cls=RendezvousHandler):
        self._httpd, port = find_port(lambda addr: RendezvousHTTPServer(
            addr, handler_cls, self._verbose))
        if self._verbose:
            logging.info('Rendezvous INFO: HTTP rendezvous server started.')

        # start the listening loop
        self._listen_thread = in_thread(target=self._httpd.serve_forever)

        return port
Example #11
    def start(self):
        logging.info('starting Spark cluster')
        if not self.start_master():
            self.stop_master()
            if not self.start_master():
                raise RuntimeError('could not start master')

        self._logfile_monitor = in_thread(self._monitor_logfile)
        if self._hosts:
            self.provide_hosts(self._hosts)
Example #12
    def __init__(self, service_name, key, nics):
        self._service_name = service_name
        self._wire = Wire(key)
        self._nics = nics
        self._server, _ = find_port(
            lambda addr: socketserver.ThreadingTCPServer(
                addr, self._make_handler()))
        self._server._block_on_close = True
        self._port = self._server.socket.getsockname()[1]
        self._addresses = self._get_local_addresses()
        self._thread = in_thread(target=self._server.serve_forever)
Example #13
    def start_server(self):
        self.httpd, port = find_port(
            lambda addr: KVStoreHTTPServer(
                addr, KVStoreHandler, self.verbose))

        self.listen_thread = in_thread(target=self.httpd.serve_forever)

        if self.verbose:
            logging.info('KVStoreServer INFO: KVStore server started. Listening on port ' + str(port))

        return port
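
Examples #10, #12 and #13 all obtain their server via find_port, which is not shown. Its usage implies that it takes a factory accepting a bind address and returns the constructed server together with the bound port. A plausible sketch is a retry loop over candidate ports; the port range and retry policy here are assumptions.

import random


def find_port(server_factory):
    """Sketch: try ports in random order until server_factory binds one."""
    min_port, max_port = 1024, 65536
    num_ports = max_port - min_port
    start = random.randrange(num_ports)
    for offset in range(num_ports):
        port = min_port + (start + offset) % num_ports
        try:
            return server_factory(('0.0.0.0', port)), port
        except OSError:
            continue  # port already taken, try the next one
    raise RuntimeError('Unable to find a free port to bind to.')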
Example #14
def _exec_middleman(command, env, exit_event, stdout, stderr, rw):
    stdout_r, stdout_w = stdout
    stderr_r, stderr_w = stderr
    r, w = rw

    # Close unused file descriptors to enforce PIPE behavior.
    stdout_r.close()
    stderr_r.close()
    w.close()
    os.setsid()

    executor_shell = subprocess.Popen(command,
                                      shell=True,
                                      env=env,
                                      stdout=stdout_w,
                                      stderr=stderr_w)

    # We don't bother stopping the on_event thread; this process calls sys.exit
    # soon, so the on_event thread has to be a daemon thread.
    on_event(exit_event,
             terminate_executor_shell_and_children,
             args=(executor_shell.pid, ),
             daemon=True)

    def kill_executor_children_if_parent_dies():
        # This read blocks until the pipe is closed on the other side
        # due to parent process termination (for any reason, including -9).
        os.read(r.fileno(), 1)
        terminate_executor_shell_and_children(executor_shell.pid)

    in_thread(kill_executor_children_if_parent_dies)

    exit_code = executor_shell.wait()
    if exit_code < 0:
        # See: https://www.gnu.org/software/bash/manual/html_node/Exit-Status.html
        exit_code = 128 + abs(exit_code)

    sys.exit(exit_code)
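
on_event, used here and in Example #25, is also external to these snippets. Its usage implies: once event is set, run fn(*args) in a background thread, optionally giving up early when a stop event fires. A sketch under those assumptions, reusing the in_thread helper; the poll interval is an assumption.

def on_event(event, fn, args=(), stop=None, daemon=True, silent=False,
             poll_interval=1.0):
    """Sketch: wait for event in a background thread, then call fn(*args).
    If a stop event is given, poll it and return without calling fn once
    it is set."""
    def fn_on_event():
        # without a stop event, block indefinitely until event is set
        timeout = poll_interval if stop is not None else None
        while not event.wait(timeout):
            if stop is not None and stop.is_set():
                return
        fn(*args)

    return in_thread(fn_on_event, daemon=daemon, silent=silent)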
Example #15
    def test_concurrent_requests_basic(self):
        sleep = 2.0
        key = secret.make_secret_key()
        service = TestSleepService(key, duration=sleep)
        client = TestSleepClient(service.addresses(), key, attempts=1)

        start = time.time()
        threads = list([in_thread(client.sleep, daemon=False) for _ in range(5)])
        for thread in threads:
            thread.join(sleep + 1.0)
            self.assertFalse(thread.is_alive(), 'thread should have terminated by now')
        duration = time.time() - start
        print('concurrent requests completed in {} seconds'.format(duration))

        self.assertGreaterEqual(duration, sleep, 'sleep requests should have been completed')
        self.assertLess(duration, sleep + 1.0, 'sleep requests should have been concurrent')
Example #16
    def _probe(self, addresses):
        result_queue = queue.Queue()
        threads = []
        for intf, intf_addresses in addresses.items():
            for addr in intf_addresses:
                thread = in_thread(target=self._probe_one, args=(intf, addr, result_queue))
                threads.append(thread)
        for t in threads:
            t.join()

        result = {}
        while not result_queue.empty():
            intf, addr = result_queue.get()
            if intf not in result:
                result[intf] = []
            result[intf].append(addr)
        return result
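
_probe_one, the per-address worker feeding the queue above, is not shown. A hypothetical method sketch that treats an address as usable when a TCP connection succeeds (assuming socket is imported; the real probe may additionally exchange a verification message with the service):

    def _probe_one(self, intf, addr, result_queue):
        # hypothetical: report (intf, addr) only if addr accepts a connection
        try:
            with socket.create_connection(addr, timeout=1.0):
                result_queue.put((intf, addr))
        except OSError:
            pass  # unreachable addresses are simply not reported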
Example #17
def _task_fn(index, driver_addresses, key, settings, use_gloo, is_elastic):
    # When deserialized on Spark workers, settings does not contain the key, so it
    # is given here explicitly. Spark RPC communicates the key and supports
    # encryption; for convenience, we put it back into settings.
    settings.key = key

    # To simplify things, each task, and each attempt (instance) of a task, is an
    # individual host in Elastic Horovod on Spark. This hides the availability of
    # shared memory among executors on the same Spark node.
    hosthash = host_hash(
        salt='{}-{}'.format(index, time.time()) if is_elastic else None)

    # provide host hash to mpirun_exec_fn.py via task service
    # gloo_exec_fn.py will get this env var set in request env as well
    os.environ['HOROVOD_HOSTNAME'] = hosthash

    task = task_service.SparkTaskService(
        index, settings.key, settings.nics,
        MINIMUM_COMMAND_LIFETIME_S if is_elastic or use_gloo else None,
        settings.verbose)
    try:
        driver_client = driver_service.SparkDriverClient(
            driver_addresses, settings.key, settings.verbose)
        driver_client.register_task(index, task.addresses(), hosthash)

        if not is_elastic:
            task.wait_for_initial_registration(settings.start_timeout)
            task_indices_on_this_host = driver_client.task_host_hash_indices(
                hosthash)
            local_rank_zero_index = task_indices_on_this_host[0]
        else:
            local_rank_zero_index = None

        # In elastic all tasks wait for task shutdown signal from driver.
        # With Gloo all tasks wait for the command to start and terminate.
        # With MPI task with first index executes orted which will run mpirun_exec_fn for all tasks.
        if is_elastic:
            # either terminate on task shutdown or command termination
            shutdown_thread = in_thread(driver_client.wait_for_task_shutdown)

            while shutdown_thread.is_alive():
                # Once the command started we wait for its termination
                if task.check_for_command_start(
                        WAIT_FOR_COMMAND_START_DELAY_SECONDS):
                    task.wait_for_command_termination()
                    if task.command_exit_code() != 0:
                        raise Exception(
                            'Command failed, making Spark task fail to restart the task'
                        )
                    break

                # While no command started, we can shutdown any time
                shutdown_thread.join(WAIT_FOR_SHUTDOWN_DELAY_SECONDS)
        elif use_gloo or index == local_rank_zero_index:
            # Either Gloo or first task with MPI.
            task.wait_for_command_start(settings.start_timeout)
            task.wait_for_command_termination()
        else:
            # The other tasks with MPI need to wait for the first task to finish.
            first_task_addresses = driver_client.all_task_addresses(
                local_rank_zero_index)
            first_task_client = \
                task_service.SparkTaskClient(local_rank_zero_index,
                                             first_task_addresses, settings.key,
                                             settings.verbose)
            first_task_client.wait_for_command_termination()

        return task.fn_result()
    finally:
        # We must not call shutdown too quickly: task clients run a command and
        # want to wait on the result. We have told the task service not to return
        # from wait_for_command_termination too quickly, so we are safe to shut
        # down here, as clients have had enough time to connect to the service.
        #
        # the shutdown has to block on running requests (wait_for_command_exit_code)
        # so they can finish serving the exit code
        # shutdown does block with network.BasicService._server._block_on_close = True
        task.shutdown()
Example #18
    def do_test_worker_compute_side(self, dispatchers: int,
                                    processing_mode: str, reuse_dataset: bool,
                                    round_robin: bool):
        # the config file for this worker
        configfile = __file__ + '.config'
        if self.rank == 0 and os.path.exists(configfile):
            raise RuntimeError(
                f'Config file exists already, please delete first: {configfile}'
            )

        # synchronize with all processes
        self.assertTrue(self.size > 1)
        logging.debug('waiting for all processes to get started')
        cluster_shape = hvd.allgather_object((self.rank, self.size),
                                             name='test_start')
        self.assertEqual(self.expected_cluster_shape, cluster_shape)
        logging.debug('all processes started')

        try:
            # start the worker
            logging.debug('starting worker process')
            worker = in_thread(
                main, (dispatchers, 'compute', configfile, self.timeout),
                daemon=True)
            # alternatively, this would run 'main' as a separate process:
            #command = f'{sys.executable} -m horovod.tensorflow.data.compute_worker --dispatchers {dispatchers} --dispatcher-side compute {configfile}'
            #worker = in_thread(safe_shell_exec.execute, (command, None, sys.stdout, sys.stderr), daemon=True)
            logging.debug('worker process started')

            # read the config file
            compute_config = TfDataServiceConfig.read(
                configfile, wait_for_file_creation=True)

            try:
                # Allow tf.data service to pre-process the pipeline
                dataset = tf.data.Dataset.range(1024)
                if reuse_dataset and round_robin:
                    dataset = dataset.repeat()
                dataset = dataset.batch(128) \
                    .send_to_data_service(compute_config, self.rank, self.size,
                                          processing_mode=processing_mode,
                                          reuse_dataset=reuse_dataset,
                                          round_robin=round_robin)

                # fetch the batches
                it = islice(dataset.as_numpy_iterator(), 8)
                actual = list([batch.tolist() for batch in it])

                # synchronize with all processes
                logging.debug('waiting for all processes to finish')
                actuals = hvd.allgather_object(actual)
                logging.debug('all processes finished')

                # assert the provided batches
                # the batches are not deterministic, so we cannot assert them too
                # thoroughly here; that would test the tf.data service anyway. All
                # we assert is that worker and send_to_data_service work together
                # nicely and produce a consumable dataset
                self.assertEqual(self.size,
                                 len(actuals),
                                 msg="one 'actual batches' from each process")

                # with reuse_dataset and FCFS (no round robin), one process might get
                # all the data while another does not get any
                if reuse_dataset and not round_robin:
                    self.assertTrue(
                        any([len(actual) > 0 for actual in actuals]),
                        msg='at least one process has at least one batch')
                else:
                    self.assertEqual([True] * self.size,
                                     [len(actual) > 0 for actual in actuals],
                                     msg='each process has at least one batch')

                for actual in actuals:
                    self.assertEqual(
                        [True] * len(actual),
                        [0 < len(batch) <= 128 for batch in actual],
                        msg=
                        f'all batches are at most 128 in size: {[len(batch) for batch in actual]}'
                    )
                    for batch in actual:
                        self.assertEqual(
                            [True] * len(batch),
                            [0 <= i < 1024 for i in batch],
                            msg=
                            f'values in batch must be within [0..1024): {batch}'
                        )

            finally:
                # shutdown compute service
                if self.rank == 0:
                    logging.debug('sending shutdown request')
                    compute = compute_config.compute_client(verbose=2)
                    compute.shutdown()
                    logging.debug('shutdown request sent')

                # in round robin mode, the worker process, once stopped, does not
                # terminate until some long timeout expires
                if not (reuse_dataset and round_robin):
                    # wait for the worker to terminate
                    logging.debug('waiting for worker to terminate')
                    worker.join(self.timeout)

                    self.assertFalse(worker.is_alive())
                    logging.debug('worker terminated')

        finally:
            # remove the configfile as it will interfere with subsequent runs of this test
            if self.rank == 0 and os.path.exists(configfile):
                os.unlink(configfile)
Example #19
    def _handle(self, req, client_address):
        if isinstance(req, RunCommandRequest):
            self._wait_cond.acquire()
            try:
                if self._command_thread is None:
                    # merge req.env into a copy of _command_env and make the
                    # result available to the executed command
                    if self._command_env:
                        env = self._command_env.copy()
                        self._add_envs(env, req.env)
                        req.env = env

                    if self._verbose >= 2:
                        print("Task service executes command: {}".format(
                            req.command))
                        if self._verbose >= 3:
                            for key, value in req.env.items():
                                if 'SECRET' in key:
                                    value = '*' * len(value)
                                print("Task service env: {} = {}".format(
                                    key, value))

                    # We only permit executing exactly one command, so this is idempotent.
                    self._command_abort = threading.Event()
                    self._command_stdout = Pipe() if req.capture_stdout else None
                    self._command_stderr = Pipe() if req.capture_stderr else None
                    args = (req.command, req.env, self._command_abort,
                            self._command_stdout, self._command_stderr,
                            self._index, req.prefix_output_with_timestamp)
                    self._command_thread = in_thread(self._run_command, args)
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, StreamCommandOutputRequest):
            # Wait for command to start
            self.wait_for_command_start()

            # We expect each command output stream to be streamed by at most one
            # client at a time
            if isinstance(req, StreamCommandStdOutRequest):
                return self.stream_output(self._command_stdout)
            elif isinstance(req, StreamCommandStdErrRequest):
                return self.stream_output(self._command_stderr)
            else:
                return CommandOutputNotCaptured()

        if isinstance(req, AbortCommandRequest):
            self._wait_cond.acquire()
            try:
                if self._command_thread is not None:
                    self._command_abort.set()
                if self._command_stdout is not None:
                    self._command_stdout.close()
                if self._command_stderr is not None:
                    self._command_stderr.close()
            finally:
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            self._wait_cond.acquire()
            try:
                self._initial_registration_complete = True
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, CommandExitCodeRequest):
            self._wait_cond.acquire()
            try:
                terminated = (self._command_thread is not None
                              and not self._command_thread.is_alive())
                return CommandExitCodeResponse(
                    terminated,
                    self._command_exit_code if terminated else None)
            finally:
                self._wait_cond.release()

        if isinstance(req, WaitForCommandExitCodeRequest):
            self._wait_cond.acquire()
            try:
                while self._command_thread is None or self._command_thread.is_alive():
                    self._wait_cond.wait(
                        max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
                return WaitForCommandExitCodeResponse(self._command_exit_code)
            finally:
                self._wait_cond.release()

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
Example #20
    def _handle(self, req, client_address):
        if isinstance(req, RunCommandRequest):
            self._wait_cond.acquire()
            try:
                if self._command_thread is None:
                    # merge req.env into a copy of _command_env and make the
                    # result available to the executed command
                    if self._command_env:
                        env = self._command_env.copy()
                        self._add_envs(env, req.env)
                        req.env = env

                    if self._verbose >= 2:
                        print("Task service executes command: {}".format(
                            req.command))
                        if self._verbose >= 3:
                            for key, value in req.env.items():
                                if 'SECRET' in key:
                                    value = '*' * len(value)
                                print("Task service env: {} = {}".format(
                                    key, value))

                    # We only permit executing exactly one command, so this is idempotent.
                    self._command_abort = threading.Event()
                    self._command_thread = in_thread(
                        target=self._run_command,
                        args=(req.command, req.env, self._command_abort))
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, AbortCommandRequest):
            self._wait_cond.acquire()
            try:
                if self._command_thread is not None:
                    self._command_abort.set()
            finally:
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            self._wait_cond.acquire()
            try:
                self._initial_registration_complete = True
            finally:
                self._wait_cond.notify_all()
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, CommandExitCodeRequest):
            self._wait_cond.acquire()
            try:
                terminated = (self._command_thread is not None
                              and not self._command_thread.is_alive())
                return CommandExitCodeResponse(
                    terminated,
                    self._command_exit_code if terminated else None)
            finally:
                self._wait_cond.release()

        if isinstance(req, WaitForCommandExitCodeRequest):
            self._wait_cond.acquire()
            try:
                while self._command_thread is None or self._command_thread.is_alive():
                    self._wait_cond.wait(
                        max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
                return WaitForCommandExitCodeResponse(self._command_exit_code)
            finally:
                self._wait_cond.release()

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
Example #21
    def wait_for_dispatcher_workers(client, dispatcher_id, queue):
        def _wait():
            client.wait_for_dispatcher_worker_registration(dispatcher_id, 10)
            queue.put(dispatcher_id)

        return in_thread(_wait, daemon=True)
Example #22
    def stop_workers(self):
        for instance in self._workers.copy():
            # pass stop_worker and its argument to in_thread rather than
            # calling it inline, so it actually runs in a background thread
            in_thread(self.stop_worker, args=(instance, ), daemon=False)
Example #23
    def wait_for_shutdown(client, queue):
        def _wait():
            client.wait_for_shutdown()
            queue.put(True)

        return in_thread(_wait, daemon=True)
Example #24
    def _handle(self, req, client_address):
        if isinstance(req, RegisterDispatcherRequest):
            self._wait_cond.acquire()
            try:
                if not 0 <= req.dispatcher_id <= self._max_dispatcher_id:
                    return IndexError(
                        f'Dispatcher id must be within [0..{self._max_dispatcher_id}]: '
                        f'{req.dispatcher_id}')

                if self._dispatcher_addresses[req.dispatcher_id] is not None and \
                   self._dispatcher_addresses[req.dispatcher_id] != req.dispatcher_address:
                    return ValueError(
                        f'Dispatcher with id {req.dispatcher_id} has already been registered under '
                        f'different address {self._dispatcher_addresses[req.dispatcher_id]}: '
                        f'{req.dispatcher_address}')

                self._dispatcher_addresses[
                    req.dispatcher_id] = req.dispatcher_address
                self._wait_cond.notify_all()
            finally:
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, WaitForDispatcherRegistrationRequest):
            self._wait_cond.acquire()
            try:
                if not 0 <= req.dispatcher_id <= self._max_dispatcher_id:
                    return IndexError(
                        f'Dispatcher id must be within [0..{self._max_dispatcher_id}]: '
                        f'{req.dispatcher_id}')

                tmout = timeout.Timeout(
                    timeout=req.timeout,
                    message=
                    'Timed out waiting for {activity}. Try to find out what takes '
                    'the dispatcher so long to register or increase timeout.')

                while self._dispatcher_addresses[req.dispatcher_id] is None:
                    self._wait_cond.wait(tmout.remaining())
                    tmout.check_time_out_for(
                        f'dispatcher {req.dispatcher_id} to register')
            except TimeoutException as e:
                return e
            finally:
                self._wait_cond.release()
            return WaitForDispatcherRegistrationResponse(
                self._dispatcher_addresses[req.dispatcher_id])

        if isinstance(req, RegisterDispatcherWorkerRequest):
            self._wait_cond.acquire()
            try:
                if not 0 <= req.dispatcher_id <= self._max_dispatcher_id:
                    return IndexError(
                        f'Dispatcher id must be within [0..{self._max_dispatcher_id}]: '
                        f'{req.dispatcher_id}')

                self._dispatcher_worker_ids[req.dispatcher_id].update(
                    {req.worker_id})
                self._wait_cond.notify_all()
            finally:
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, WaitForDispatcherWorkerRegistrationRequest):
            # if there is only a single dispatcher, wait for that one instead of the requested one
            dispatcher_id = req.dispatcher_id if self._max_dispatcher_id > 0 else 0

            self._wait_cond.acquire()
            try:
                if not 0 <= req.dispatcher_id <= self._max_dispatcher_id:
                    return IndexError(
                        f'Dispatcher id must be within [0..{self._max_dispatcher_id}]: '
                        f'{req.dispatcher_id}')

                tmout = timeout.Timeout(
                    timeout=req.timeout,
                    message=
                    'Timed out waiting for {activity}. Try to find out what takes '
                    'the workers so long to register or increase timeout.')

                while len(self._dispatcher_worker_ids[dispatcher_id]
                          ) < self._workers_per_dispatcher:
                    self._wait_cond.wait(tmout.remaining())
                    tmout.check_time_out_for(
                        f'workers for dispatcher {dispatcher_id} to register')
            except TimeoutException as e:
                return e
            finally:
                self._wait_cond.release()
            return network.AckResponse()

        if isinstance(req, ShutdownRequest):
            in_thread(self.shutdown)
            return network.AckResponse()

        if isinstance(req, WaitForShutdownRequest):
            self._wait_cond.acquire()
            try:
                while not self._shutdown:
                    self._wait_cond.wait()
            finally:
                self._wait_cond.release()
            return network.AckResponse()

        return super()._handle(req, client_address)
Example #25
def execute(command,
            env=None,
            stdout=None,
            stderr=None,
            index=None,
            events=None,
            prefix_output_with_timestamp=False):
    ctx = multiprocessing.get_context('spawn')

    # When this event is set, signal to middleman to terminate its children and exit.
    exit_event = _create_event(ctx)

    # Make a pipe for the subprocess stdout/stderr.
    (stdout_r, stdout_w) = ctx.Pipe()
    (stderr_r, stderr_w) = ctx.Pipe()

    # This Pipe is how we ensure that the executed process is properly terminated (not orphaned) if
    # the parent process is hard killed (-9). If the parent (this process) is killed for any reason,
    # this Pipe will be closed, which can be detected by the middleman. When the middleman sees the
    # closed Pipe, it will issue a SIGTERM to the subprocess executing the command. The assumption
    # here is that users will be inclined to hard kill this process, not the middleman.
    (r, w) = ctx.Pipe()

    middleman = ctx.Process(target=_exec_middleman,
                            args=(command, env, exit_event,
                                  (stdout_r, stdout_w), (stderr_r,
                                                         stderr_w), (r, w)))
    middleman.start()

    # Close unused file descriptors to enforce PIPE behavior.
    r.close()
    stdout_w.close()
    stderr_w.close()

    # Redirect command stdout & stderr to provided streams or sys.stdout/sys.stderr.
    # This is useful for Jupyter Notebook that uses custom sys.stdout/sys.stderr or
    # for redirecting to a file on disk.
    if stdout is None:
        stdout = sys.stdout
    if stderr is None:
        stderr = sys.stderr

    stdout_fwd = in_thread(target=forward_stream,
                           args=(stdout_r, stdout, 'stdout', index,
                                 prefix_output_with_timestamp))
    stderr_fwd = in_thread(target=forward_stream,
                           args=(stderr_r, stderr, 'stderr', index,
                                 prefix_output_with_timestamp))

    # TODO: Currently this requires explicit declaration of the events and signal handler to set
    #  the event (gloo_run.py:_launch_jobs()). Need to figure out a generalized way to hide this behind
    #  interfaces.
    stop = threading.Event()
    events = events or []
    for event in events:
        on_event(event, exit_event.set, stop=stop, silent=True)

    try:
        middleman.join()
    except:
        # interrupted, send middleman TERM signal which will terminate children
        exit_event.set()
        while True:
            try:
                middleman.join()
                break
            except:
                # interrupted, wait for middleman to finish
                pass
    finally:
        stop.set()

    stdout_fwd.join()
    stderr_fwd.join()

    return middleman.exitcode
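
A hypothetical invocation of execute (the command string is illustrative): the call blocks until the middleman process exits and returns the command's exit code, with output forwarded to this process's stdout and stderr. The __main__ guard matters because the function uses the 'spawn' multiprocessing context.

if __name__ == '__main__':
    exit_code = execute('echo out; echo err >&2; sleep 1')
    print('command finished with exit code {}'.format(exit_code))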