コード例 #1
0
    def _handle(self, req, client_address):
        """Handle task registration and task-address lookup requests on the driver.

        * ``RegisterTaskRequest``: stores the addresses the task declared, the
          subset of them matching the connection's source IP (used by the driver
          for probing), and the index <-> host-hash maps. Waiters on
          ``self._wait_cond`` are notified afterwards, even on failure.
        * ``RegisterTaskToTaskAddressesRequest``: forwards to
          ``self.register_task_to_task_addresses``.
        * ``AllTaskAddressesRequest``: returns every address the task declared.

        Any other request is delegated to the next ``_handle`` in the MRO.

        Args:
            req: the deserialized request object.
            client_address: address tuple of the connecting peer; element 0 is
                used as the source IP.

        Returns:
            ``network.AckResponse`` for registrations, an
            ``AllTaskAddressesResponse`` for address lookups, or whatever the
            parent handler returns.

        Raises:
            IndexError: if a ``RegisterTaskRequest`` carries an index outside
                ``[0, self._num_proc)``.
        """
        if isinstance(req, RegisterTaskRequest):
            with self._wait_cond:
                try:
                    # Explicit check rather than `assert`: asserts are stripped
                    # under `python -O`, which would let an out-of-range index through.
                    if not 0 <= req.index < self._num_proc:
                        raise IndexError(
                            'Task index must be within [0..{}): {}'.format(
                                self._num_proc, req.index))
                    print("Setting address from RegisterTaskRequest: " +
                          str(req.task_addresses))
                    self._all_task_addresses[req.index] = req.task_addresses
                    # Just use source address for service for fast probing.
                    self._task_addresses_for_driver[req.index] = \
                        self._filter_by_ip(req.task_addresses, client_address[0])
                    if not self._task_addresses_for_driver[req.index]:
                        # No match is possible if one of the servers is behind NAT.
                        # We don't throw exception here, but will allow the following
                        # code fail with NoValidAddressesFound.
                        print(
                            'ERROR: Task {index} declared addresses {task_addresses}, '
                            'but has connected from a different address {source}. '
                            'This is not supported. Is the server behind NAT?'
                            ''.format(index=req.index,
                                      task_addresses=req.task_addresses,
                                      source=client_address[0]))

                    # Remove this index from the host hash it was registered under
                    # earlier, if the task now reports a different host hash.
                    if req.index in self._task_index_host_hash:
                        earlier_host_hash = self._task_index_host_hash[req.index]
                        if earlier_host_hash != req.host_hash:
                            self._task_host_hash_indices[earlier_host_hash].remove(
                                req.index)

                    # Make index -> host hash map.
                    self._task_index_host_hash[req.index] = req.host_hash

                    # Make host hash -> indices map. A task re-registering under the
                    # same host hash is already listed; appending it again would
                    # duplicate the index and shift positional (local-rank) lookups.
                    host_indices = self._task_host_hash_indices.setdefault(
                        req.host_hash, [])
                    if req.index not in host_indices:
                        host_indices.append(req.index)
                        # TODO: this sorting is a problem in elastic horovod
                        host_indices.sort()
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, RegisterTaskToTaskAddressesRequest):
            print("Task being registered to task address: " +
                  str(req.task_addresses))
            self.register_task_to_task_addresses(req.index, req.task_addresses)
            return network.AckResponse()

        if isinstance(req, AllTaskAddressesRequest):
            return AllTaskAddressesResponse(
                self._all_task_addresses[req.index])

        return super(BasicDriverService, self)._handle(req, client_address)
コード例 #2
0
ファイル: test_service.py プロジェクト: zw0610/horovod
    def _handle(self, req, client_address):
        """Sleep for ``self._duration`` seconds on a ``SleepRequest``, then ack.

        Every other request type is passed on to the parent service's handler.
        """
        if not isinstance(req, SleepRequest):
            return super(TestSleepService, self)._handle(req, client_address)

        print('{}: sleeping for client {}'.format(time.time(), client_address))
        time.sleep(self._duration)
        return network.AckResponse()
コード例 #3
0
ファイル: worker.py プロジェクト: zw0610/horovod
    def _handle(self, req, client_address):
        """Forward host-update notifications to the manager; delegate anything else.

        A ``HostsUpdatedRequest`` is handed to
        ``self._manager.handle_hosts_updated`` with the request's timestamp and
        acknowledged. All other requests go to the parent service's handler.
        """
        if not isinstance(req, HostsUpdatedRequest):
            return super(WorkerNotificationService,
                         self)._handle(req, client_address)

        self._manager.handle_hosts_updated(req.timestamp)
        return network.AckResponse()
コード例 #4
0
    def _handle(self, req, client_address):
        """Serve Spark-driver specific requests; delegate the rest to the parent.

        * ``TaskHostHashIndicesRequest``: task indices registered for a host hash.
        * ``SetLocalRankToRankRequest``: bind ``req.rank`` to the task index found
          at position ``req.local_rank`` among the indices of ``req.host``. Any
          rank previously bound to that same index is unbound first.
        * ``TaskIndexByRankRequest``: the task index bound to ``req.rank``.
        * ``CodeRequest``: the function and arguments stored on this service.
        * ``WaitForTaskShutdownRequest``: blocks on ``self._task_shutdown`` and
          then acks.

        The rank <-> index map is only touched while holding ``self._lock``.
        """
        if isinstance(req, TaskHostHashIndicesRequest):
            host_indices = self._task_host_hash_indices[req.host_hash]
            return TaskHostHashIndicesResponse(host_indices)

        if isinstance(req, SetLocalRankToRankRequest):
            with self._lock:
                task_index = self._task_host_hash_indices[req.host][req.local_rank]

                # Unbind the first rank (in insertion order) currently mapped to
                # this task index. Iterate over a snapshot because the dict is
                # mutated inside the loop.
                for bound_rank, bound_index in list(self._ranks_to_indices.items()):
                    if bound_index == task_index:
                        del self._ranks_to_indices[bound_rank]
                        break

                # memorize rank's index
                self._ranks_to_indices[req.rank] = task_index
            return SetLocalRankToRankResponse(task_index)

        if isinstance(req, TaskIndexByRankRequest):
            with self._lock:
                return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])

        if isinstance(req, CodeRequest):
            return CodeResponse(self._fn, self._args, self._kwargs)

        if isinstance(req, WaitForTaskShutdownRequest):
            self._task_shutdown.wait()
            return network.AckResponse()

        return super(SparkDriverService, self)._handle(req, client_address)
コード例 #5
0
    def _handle(self, req, client_address):
        """Serve command-execution requests for this task service.

        * ``RunCommandRequest``: starts ``self._run_command`` in a thread, at
          most once. Later run requests are acked without starting anything.
          When ``self._command_env`` is set, the request env is merged into a
          copy of it via ``self._add_envs``.
        * ``AbortCommandRequest``: sets the abort event of a started command.
        * ``NotifyInitialRegistrationCompleteRequest``: sets
          ``self._initial_registration_complete`` and wakes waiters.
        * ``CommandExitCodeRequest``: non-blocking; reports whether the command
          thread has finished and, if so, its exit code.
        * ``WaitForCommandExitCodeRequest``: blocks until a command thread exists
          and has finished. It re-checks at least every
          ``max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY)`` seconds.
        * ``RegisterCodeResultRequest``: stores ``req.result``.

        Anything else is delegated to the parent service.
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    # We only permit executing exactly one command, so this is idempotent.
                    if self._command_thread is None:
                        if self._command_env:
                            # Layer req.env onto a copy of the service env so the
                            # executed command sees both.
                            merged_env = self._command_env.copy()
                            self._add_envs(merged_env, req.env)
                            req.env = merged_env

                        if self._verbose >= 2:
                            print("Task service executes command: {}".format(
                                req.command))
                            if self._verbose >= 3:
                                for env_key, env_value in req.env.items():
                                    shown = ('*' * len(env_value)
                                             if 'SECRET' in env_key else env_value)
                                    print("Task service env: {} = {}".format(
                                        env_key, shown))

                        self._command_abort = threading.Event()
                        self._command_thread = in_thread(
                            target=self._run_command,
                            args=(req.command, req.env, self._command_abort))
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, AbortCommandRequest):
            with self._wait_cond:
                if self._command_thread is not None:
                    self._command_abort.set()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                self._initial_registration_complete = True
                self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandExitCodeRequest):
            with self._wait_cond:
                finished = (self._command_thread is not None
                            and not self._command_thread.is_alive())
                exit_code = self._command_exit_code if finished else None
                return CommandExitCodeResponse(finished, exit_code)

        if isinstance(req, WaitForCommandExitCodeRequest):
            with self._wait_cond:
                # Keep waiting while no command was started or it is still running.
                while self._command_thread is None or self._command_thread.is_alive():
                    self._wait_cond.wait(
                        max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
                return WaitForCommandExitCodeResponse(self._command_exit_code)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
コード例 #6
0
ファイル: task_service.py プロジェクト: zyx1213271098/horovod
    def _handle(self, req, client_address):
        """Serve command-execution and output-streaming requests for this task.

        * ``RunCommandRequest``: starts ``self._run_command`` in a thread, at
          most once. ``Pipe`` objects are created for stdout/stderr only when the
          request asks to capture them. The request env is merged into a copy
          of ``self._command_env`` when that is set.
        * ``StreamCommandOutputRequest``: waits for the command to start, then
          streams the requested pipe. Unknown subtypes get
          ``CommandOutputNotCaptured``.
        * ``AbortCommandRequest``: sets the abort event of a started command and
          closes any capture pipes.
        * ``NotifyInitialRegistrationCompleteRequest``: sets the flag and wakes
          waiters.
        * ``CommandExitCodeRequest``: non-blocking termination / exit-code query.
        * ``WaitForCommandExitCodeRequest``: blocks until a command thread exists
          and has finished.
        * ``RegisterCodeResultRequest``: stores ``req.result``.

        Anything else is delegated to the parent service.
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    # We only permit executing exactly one command, so this is idempotent.
                    if self._command_thread is None:
                        if self._command_env:
                            # Layer req.env onto a copy of the service env so the
                            # executed command sees both.
                            merged_env = self._command_env.copy()
                            self._add_envs(merged_env, req.env)
                            req.env = merged_env

                        if self._verbose >= 2:
                            print("Task service executes command: {}".format(
                                req.command))
                            if self._verbose >= 3:
                                for env_key, env_value in req.env.items():
                                    shown = ('*' * len(env_value)
                                             if 'SECRET' in env_key else env_value)
                                    print("Task service env: {} = {}".format(
                                        env_key, shown))

                        self._command_abort = threading.Event()
                        self._command_stdout = Pipe() if req.capture_stdout else None
                        self._command_stderr = Pipe() if req.capture_stderr else None
                        run_args = (req.command, req.env, self._command_abort,
                                    self._command_stdout, self._command_stderr,
                                    self._index, req.prefix_output_with_timestamp)
                        self._command_thread = in_thread(self._run_command, run_args)
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, StreamCommandOutputRequest):
            # Wait for command to start
            self.wait_for_command_start()

            # We only expect streaming each command output stream once concurrently
            if isinstance(req, StreamCommandStdOutRequest):
                stream = self._command_stdout
            elif isinstance(req, StreamCommandStdErrRequest):
                stream = self._command_stderr
            else:
                return CommandOutputNotCaptured()
            return self.stream_output(stream)

        if isinstance(req, AbortCommandRequest):
            with self._wait_cond:
                if self._command_thread is not None:
                    self._command_abort.set()
                if self._command_stdout is not None:
                    self._command_stdout.close()
                if self._command_stderr is not None:
                    self._command_stderr.close()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                self._initial_registration_complete = True
                self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandExitCodeRequest):
            with self._wait_cond:
                finished = (self._command_thread is not None
                            and not self._command_thread.is_alive())
                exit_code = self._command_exit_code if finished else None
                return CommandExitCodeResponse(finished, exit_code)

        if isinstance(req, WaitForCommandExitCodeRequest):
            with self._wait_cond:
                # Keep waiting while no command was started or it is still running.
                while self._command_thread is None or self._command_thread.is_alive():
                    self._wait_cond.wait(
                        max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
                return WaitForCommandExitCodeResponse(self._command_exit_code)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
コード例 #7
0
ファイル: compute_service.py プロジェクト: chongxiaoc/horovod
    def _handle(self, req, client_address):
        """Serve dispatcher / worker registration, rendezvous and shutdown requests.

        Validation and timeout failures are *returned* as exception instances
        (``IndexError``, ``ValueError``, ``TimeoutException``) rather than
        raised. This matches the original code; presumably the network layer
        ships them back to the client as the response.

        * ``RegisterDispatcherRequest``: records a dispatcher's address. Returns
          a ``ValueError`` if the id is already bound to a different address.
        * ``WaitForDispatcherRegistrationRequest``: blocks (with
          ``req.timeout``) until the dispatcher registered, then returns its
          address.
        * ``RegisterDispatcherWorkerRequest``: adds a worker id to a dispatcher.
        * ``WaitForDispatcherWorkerRegistrationRequest``: blocks (with
          ``req.timeout``) until ``self._workers_per_dispatcher`` workers
          registered. With a single dispatcher, the requested id is ignored and
          dispatcher 0 is awaited.
        * ``ShutdownRequest``: runs ``self.shutdown`` on a separate thread.
        * ``WaitForShutdownRequest``: blocks until ``self._shutdown`` is set.

        Anything else is delegated to the parent service.
        """
        def dispatcher_id_error(dispatcher_id):
            # Return an IndexError for an out-of-range id, or None when valid.
            if 0 <= dispatcher_id <= self._max_dispatcher_id:
                return None
            return IndexError(
                f'Dispatcher id must be within [0..{self._max_dispatcher_id}]: '
                f'{dispatcher_id}')

        if isinstance(req, RegisterDispatcherRequest):
            with self._wait_cond:
                error = dispatcher_id_error(req.dispatcher_id)
                if error is not None:
                    return error

                registered = self._dispatcher_addresses[req.dispatcher_id]
                if registered is not None and registered != req.dispatcher_address:
                    return ValueError(
                        f'Dispatcher with id {req.dispatcher_id} has already been registered under '
                        f'different address {registered}: '
                        f'{req.dispatcher_address}')

                self._dispatcher_addresses[req.dispatcher_id] = req.dispatcher_address
                self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, WaitForDispatcherRegistrationRequest):
            with self._wait_cond:
                error = dispatcher_id_error(req.dispatcher_id)
                if error is not None:
                    return error

                try:
                    tmout = timeout.Timeout(
                        timeout=req.timeout,
                        message=
                        'Timed out waiting for {activity}. Try to find out what takes '
                        'the dispatcher so long to register or increase timeout.')

                    while self._dispatcher_addresses[req.dispatcher_id] is None:
                        self._wait_cond.wait(tmout.remaining())
                        tmout.check_time_out_for(
                            f'dispatcher {req.dispatcher_id} to register')
                except TimeoutException as e:
                    return e
                # Read the address while still holding the lock.
                return WaitForDispatcherRegistrationResponse(
                    self._dispatcher_addresses[req.dispatcher_id])

        if isinstance(req, RegisterDispatcherWorkerRequest):
            with self._wait_cond:
                error = dispatcher_id_error(req.dispatcher_id)
                if error is not None:
                    return error

                self._dispatcher_worker_ids[req.dispatcher_id].add(req.worker_id)
                self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, WaitForDispatcherWorkerRegistrationRequest):
            # if there is only a single dispatcher, wait for that one instead of the requested one
            dispatcher_id = req.dispatcher_id if self._max_dispatcher_id > 0 else 0

            with self._wait_cond:
                # Validate the id we actually wait for. Validating the requested id
                # would reject every non-zero id in the single-dispatcher case and
                # make the remapping above dead code. With several dispatchers both
                # ids are equal, so the check and message are unchanged there.
                error = dispatcher_id_error(dispatcher_id)
                if error is not None:
                    return error

                try:
                    tmout = timeout.Timeout(
                        timeout=req.timeout,
                        message=
                        'Timed out waiting for {activity}. Try to find out what takes '
                        'the workers so long to register or increase timeout.')

                    while len(self._dispatcher_worker_ids[dispatcher_id]
                              ) < self._workers_per_dispatcher:
                        self._wait_cond.wait(tmout.remaining())
                        tmout.check_time_out_for(
                            f'workers for dispatcher {dispatcher_id} to register')
                except TimeoutException as e:
                    return e
            return network.AckResponse()

        if isinstance(req, ShutdownRequest):
            in_thread(self.shutdown)
            return network.AckResponse()

        if isinstance(req, WaitForShutdownRequest):
            with self._wait_cond:
                while not self._shutdown:
                    self._wait_cond.wait()
            return network.AckResponse()

        return super()._handle(req, client_address)