Example #1
0
    def _handle(self, req, client_address):
        """Dispatch a driver-service request and return its response.

        Handled request types:
          * RegisterTaskRequest -- record the task's addresses, the subset
            matching the client's source IP, and its host-hash -> index entry.
          * RegisterTaskToTaskAddressesRequest -- record the addresses tasks
            should use to reach this task.
          * AllTaskAddressesRequest -- return the addresses registered for
            ``req.index``.
        Any other request is delegated to the parent service.

        :param req: the deserialized request object.
        :param client_address: address of the connecting client; element
            ``[0]`` is the source IP used to filter the task's addresses.
        :return: the response object for ``req``.
        """
        if isinstance(req, RegisterTaskRequest):
            with self._wait_cond:
                try:
                    assert 0 <= req.index < self._num_proc
                    self._all_task_addresses[req.index] = req.task_addresses
                    # Just use source address for service for fast probing.
                    self._task_addresses_for_driver[req.index] = \
                        self._filter_by_ip(req.task_addresses, client_address[0])
                    # Make host hash -> indices map. A task re-registering
                    # under the same index must not be listed twice.
                    if req.host_hash not in self._task_host_hash_indices:
                        self._task_host_hash_indices[req.host_hash] = []
                    host_indices = self._task_host_hash_indices[req.host_hash]
                    if req.index not in host_indices:
                        host_indices.append(req.index)
                    host_indices.sort()
                finally:
                    # Wake waiters even when the index check above failed.
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, RegisterTaskToTaskAddressesRequest):
            with self._wait_cond:
                try:
                    assert 0 <= req.index < self._num_proc
                    self._task_addresses_for_tasks[req.index] = req.task_addresses
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, AllTaskAddressesRequest):
            return AllTaskAddressesResponse(
                self._all_task_addresses[req.index])

        return super(BasicDriverService, self)._handle(req, client_address)
Example #2
0
    def _handle(self, req, client_address):
        """Dispatch a task-service request and return its response.

        Handled request types:
          * RunCommandRequest -- merge ``req.env`` into ``_command_env`` (when
            set), optionally log the command and environment, and start the
            command once in a daemon thread.
          * NotifyInitialRegistrationCompleteRequest -- flag registration done.
          * CommandTerminatedRequest -- report whether the command thread has
            been started and has finished.
          * RegisterCodeResultRequest -- store the result sent back.
        Any other request is delegated to the parent service.

        :param req: the deserialized request object.
        :param client_address: address of the connecting client.
        :return: the response object for ``req``.
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    if self._command_thread is None:
                        # we add req.env to _command_env and make this available to the executed command
                        if self._command_env:
                            env = self._command_env.copy()
                            self._add_envs(env, req.env)
                            req.env = env

                        if self._verbose >= 2:
                            print("Task service executes command: {}".format(
                                req.command))
                            for key, value in req.env.items():
                                # Compare case-insensitively so that secrets
                                # with lower/mixed-case names are masked too.
                                if 'SECRET' in key.upper():
                                    value = '*' * len(value)
                                print("Task service env: {} = {}".format(
                                    key, value))

                        # We only permit executing exactly one command, so this is idempotent.
                        self._command_thread = threading.Thread(
                            target=safe_shell_exec.execute,
                            args=(req.command, req.env))
                        self._command_thread.daemon = True
                        self._command_thread.start()
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                try:
                    self._initial_registration_complete = True
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandTerminatedRequest):
            with self._wait_cond:
                terminated = (self._command_thread is not None
                              and not self._command_thread.is_alive())
            return CommandTerminatedResponse(terminated)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
Example #3
0
    def _handle(self, req, client_address):
        """Dispatch a driver-service request and return its response.

        Handled request types:
          * RegisterTaskRequest -- record the task's addresses and the subset
            matching the client's source IP, and update the index <-> host
            hash maps. If the task previously registered under a different
            host hash, its index is moved to the new host's list.
          * RegisterTaskToTaskAddressesRequest -- delegated to
            ``register_task_to_task_addresses``.
          * AllTaskAddressesRequest -- return the addresses registered for
            ``req.index``.
        Any other request is delegated to the parent service.

        :param req: the deserialized request object.
        :param client_address: address of the connecting client; element
            ``[0]`` is the source IP used to filter the task's addresses.
        :return: the response object for ``req``.
        """
        if isinstance(req, RegisterTaskRequest):
            with self._wait_cond:
                try:
                    assert 0 <= req.index < self._num_proc
                    self._all_task_addresses[req.index] = req.task_addresses
                    # Just use source address for service for fast probing.
                    self._task_addresses_for_driver[req.index] = \
                        self._filter_by_ip(req.task_addresses, client_address[0])
                    if not self._task_addresses_for_driver[req.index]:
                        # No match is possible if one of the servers is behind NAT.
                        # We don't throw exception here, but will allow the following
                        # code fail with NoValidAddressesFound.
                        print(
                            'ERROR: Task {index} declared addresses {task_addresses}, '
                            'but has connected from a different address {source}. '
                            'This is not supported. Is the server behind NAT?'
                            ''.format(index=req.index,
                                      task_addresses=req.task_addresses,
                                      source=client_address[0]))

                    # Remove host hash earlier registered under this index.
                    if req.index in self._task_index_host_hash:
                        earlier_host_hash = self._task_index_host_hash[req.index]
                        if earlier_host_hash != req.host_hash:
                            self._task_host_hash_indices[earlier_host_hash].remove(
                                req.index)

                    # Make index -> host hash map.
                    self._task_index_host_hash[req.index] = req.host_hash

                    # Make host hash -> indices map. Re-registering under the
                    # same host hash must not list the index twice (a single
                    # .remove() above would otherwise leave a stale copy
                    # behind on a later host change).
                    if req.host_hash not in self._task_host_hash_indices:
                        self._task_host_hash_indices[req.host_hash] = []
                    host_indices = self._task_host_hash_indices[req.host_hash]
                    if req.index not in host_indices:
                        host_indices.append(req.index)
                    # TODO: this sorting is a problem in elastic horovod
                    host_indices.sort()
                finally:
                    # Wake waiters even when registration failed part-way.
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, RegisterTaskToTaskAddressesRequest):
            self.register_task_to_task_addresses(req.index, req.task_addresses)
            return network.AckResponse()

        if isinstance(req, AllTaskAddressesRequest):
            return AllTaskAddressesResponse(
                self._all_task_addresses[req.index])

        return super(BasicDriverService, self)._handle(req, client_address)
Example #4
0
    def _handle(self, req, client_address):
        """Forward a HostsUpdatedRequest's timestamp to the manager.

        Non-HostsUpdatedRequest messages fall through to the parent service.
        """
        if not isinstance(req, HostsUpdatedRequest):
            return super(WorkerNotificationService,
                         self)._handle(req, client_address)

        self._manager.handle_hosts_updated(req.timestamp)
        return network.AckResponse()
Example #5
0
    def _handle(self, req, client_address):
        """Block for ``self._duration`` seconds on a SleepRequest, then ack.

        Logs the current time and client address before sleeping; any other
        request type is delegated to the parent service.
        """
        if not isinstance(req, SleepRequest):
            return super(TestSleepService, self)._handle(req, client_address)

        print('{}: sleeping for client {}'.format(time.time(), client_address))
        time.sleep(self._duration)
        return network.AckResponse()
Example #6
0
    def _handle(self, req, client_address):
        """Serve task-service requests; unknown types go to the parent class.

        Handled request types:
          * RunCommandRequest -- copy the variables named in
            ``_service_env_keys`` from this process's environment into
            ``req.env`` and start the command once in a daemon thread.
          * NotifyInitialRegistrationCompleteRequest -- flag registration done.
          * CommandTerminatedRequest -- report whether the command thread has
            been started and has finished.
          * RegisterCodeResultRequest -- store the result sent back.
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    # Only the first RunCommandRequest launches anything; later
                    # ones are acknowledged without effect.
                    if self._command_thread is None:
                        # Values present in this service's os.environ replace
                        # whatever req.env already holds for the same key.
                        for key in self._service_env_keys:
                            service_value = os.environ.get(key)
                            if service_value is None:
                                continue
                            req.env[key] = service_value

                        worker = threading.Thread(
                            target=safe_shell_exec.execute,
                            args=(req.command, req.env))
                        self._command_thread = worker
                        worker.daemon = True
                        worker.start()
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                self._initial_registration_complete = True
                self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandTerminatedRequest):
            with self._wait_cond:
                worker = self._command_thread
                finished = worker is not None and not worker.is_alive()
            return CommandTerminatedResponse(finished)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
Example #7
0
    def _handle(self, req, client_address):
        """Serve task-service requests; unknown types go to the parent class.

        Handled request types:
          * RunCommandRequest -- start ``req.command`` with ``req.env`` once,
            in a daemon thread running ``safe_shell_exec.execute``.
          * NotifyInitialRegistrationCompleteRequest -- flag registration done.
          * CommandTerminatedRequest -- report whether the command thread has
            been started and has finished.
          * RegisterCodeResultRequest -- store the result sent back.
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    # Only the first RunCommandRequest launches anything; later
                    # ones are acknowledged without effect.
                    if self._command_thread is None:
                        worker = threading.Thread(
                            target=safe_shell_exec.execute,
                            args=(req.command, req.env))
                        self._command_thread = worker
                        worker.daemon = True
                        worker.start()
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                self._initial_registration_complete = True
                self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandTerminatedRequest):
            with self._wait_cond:
                worker = self._command_thread
                finished = worker is not None and not worker.is_alive()
            return CommandTerminatedResponse(finished)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
Example #8
0
    def _handle(self, req, client_address):
        """Serve Spark driver requests; unknown types go to the parent class.

        Handled request types:
          * TaskHostHashIndicesRequest -- task indices registered for a host.
          * SetLocalRankToRankRequest -- resolve (host, local_rank) to a task
            index and map ``req.rank`` to it, dropping the rank that mapped to
            that index before (if any).
          * TaskIndexByRankRequest -- task index previously mapped to a rank.
          * CodeRequest -- the function and arguments to execute.
          * WaitForTaskShutdownRequest -- block until ``_task_shutdown`` is set.
        """
        if isinstance(req, TaskHostHashIndicesRequest):
            return TaskHostHashIndicesResponse(
                self._task_host_hash_indices[req.host_hash])

        if isinstance(req, SetLocalRankToRankRequest):
            with self._lock:
                # Task index registered for this host at the given local rank.
                index = self._task_host_hash_indices[req.host][req.local_rank]

                # Reverse lookup: the first rank currently mapped to this
                # index, if any, is forgotten before the new mapping is stored.
                stale_rank = next(
                    (rank for rank, task_index in self._ranks_to_indices.items()
                     if task_index == index),
                    None)
                if stale_rank is not None:
                    del self._ranks_to_indices[stale_rank]

                self._ranks_to_indices[req.rank] = index
            return SetLocalRankToRankResponse(index)

        if isinstance(req, TaskIndexByRankRequest):
            with self._lock:
                return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])

        if isinstance(req, CodeRequest):
            return CodeResponse(self._fn, self._args, self._kwargs)

        if isinstance(req, WaitForTaskShutdownRequest):
            self._task_shutdown.wait()
            return network.AckResponse()

        return super(SparkDriverService, self)._handle(req, client_address)
Example #9
0
    def _handle(self, req, client_address):
        """Dispatch a task-service request and return its response.

        Handled request types:
          * RunCommandRequest -- merge ``req.env`` into ``_command_env`` (when
            set), optionally log the command (verbose >= 2) and environment
            (verbose >= 3), then start ``self._run_command`` once via
            ``in_thread`` together with a fresh abort event.
          * AbortCommandRequest -- set the abort event if a command was started.
          * NotifyInitialRegistrationCompleteRequest -- flag registration done.
          * CommandExitCodeRequest -- report whether the command has finished
            and, if so, its exit code.
          * WaitForCommandExitCodeRequest -- block until a command has been
            started and its thread has finished, then return the exit code.
          * RegisterCodeResultRequest -- store the result sent back.
        Any other request is delegated to the parent service.

        :param req: the deserialized request object.
        :param client_address: address of the connecting client.
        :return: the response object for ``req``.
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    if self._command_thread is None:
                        # we add req.env to _command_env and make this available to the executed command
                        if self._command_env:
                            env = self._command_env.copy()
                            self._add_envs(env, req.env)
                            req.env = env

                        if self._verbose >= 2:
                            print("Task service executes command: {}".format(
                                req.command))
                            if self._verbose >= 3:
                                for key, value in req.env.items():
                                    # Compare case-insensitively so that secrets
                                    # with lower/mixed-case names are masked too.
                                    if 'SECRET' in key.upper():
                                        value = '*' * len(value)
                                    print("Task service env: {} = {}".format(
                                        key, value))

                        # We only permit executing exactly one command, so this is idempotent.
                        self._command_abort = threading.Event()
                        self._command_thread = in_thread(
                            target=self._run_command,
                            args=(req.command, req.env, self._command_abort))
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, AbortCommandRequest):
            with self._wait_cond:
                if self._command_thread is not None:
                    self._command_abort.set()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                try:
                    self._initial_registration_complete = True
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandExitCodeRequest):
            with self._wait_cond:
                terminated = (self._command_thread is not None
                              and not self._command_thread.is_alive())
                return CommandExitCodeResponse(
                    terminated,
                    self._command_exit_code if terminated else None)

        if isinstance(req, WaitForCommandExitCodeRequest):
            with self._wait_cond:
                # Wait with a timeout and re-check, presumably because the
                # finishing command thread does not always notify this
                # condition.
                while (self._command_thread is None
                       or self._command_thread.is_alive()):
                    self._wait_cond.wait(
                        max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
                return WaitForCommandExitCodeResponse(self._command_exit_code)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)