コード例 #1
0
ファイル: __init__.py プロジェクト: zpcalan/horovod
def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    """Run the driver-supplied function in this task and report its result.

    The function registers this process's (local rank, rank) pair with the
    Spark driver, looks up the task service it belongs to, publishes that
    task's resources through ``task_info``, then fetches the user code from
    the driver, calls it and hands the return value to the task service.

    :param driver_addresses: addresses passed to ``SparkDriverClient``
    :param settings: settings object; only ``settings.verbose`` is read here
    :param rank_env: name of the environment variable holding the rank
    :param local_rank_env: name of the environment variable holding the
        local rank
    :raises KeyError: if a required environment variable is not set
    :raises ValueError: if a rank variable cannot be parsed as an integer
    """
    # Start the parent-process monitor so we die if the parent terminates.
    in_thread(target=_parent_process_monitor, args=(os.getppid(), ))

    environ = os.environ
    secret_key = codec.loads_base64(environ[secret.HOROVOD_SECRET_KEY])
    global_rank = int(environ[rank_env])
    node_local_rank = int(environ[local_rank_env])

    verbose = settings.verbose
    driver = driver_service.SparkDriverClient(
        driver_addresses, secret_key, verbose=verbose)

    # Report the local rank -> rank mapping to the driver. In elastic mode
    # the driver already knows this mapping; the call is made anyway so that
    # elastic and static mode share a single code path.
    hostname_hash = environ['HOROVOD_HOSTNAME']
    index = driver.set_local_rank_to_rank(
        hostname_hash, node_local_rank, global_rank)

    # Ask the task service behind this index for its available resources.
    addresses = driver.all_task_addresses(index)
    task = task_service.SparkTaskClient(
        index, addresses, secret_key, verbose=verbose)
    resources = task.resources()
    task_info.set_resources(resources)

    # Fetch the user code from the driver, run it, and publish the outcome.
    func, func_args, func_kwargs = driver.code()
    outcome = func(*func_args, **func_kwargs)
    task.register_code_result(outcome)
コード例 #2
0
def task_exec(driver_addresses, settings, rank_env):
    """Run the driver-supplied function in this task and report its result.

    The rank read from the environment is translated by the driver into a
    task index; that task's service is then queried for its resources,
    which are published through ``task_info``. Finally the user code is
    fetched from the driver, called, and its return value is registered
    with the task service.

    :param driver_addresses: addresses passed to ``SparkDriverClient``
    :param settings: settings object; only ``settings.verbose`` is read here
    :param rank_env: name of the environment variable holding the rank
    :raises KeyError: if a required environment variable is not set
    :raises ValueError: if the rank variable cannot be parsed as an integer
    """
    # Start the parent-process monitor so we die if the parent terminates.
    in_thread(target=_parent_process_monitor, args=(os.getppid(), ))

    environ = os.environ
    secret_key = codec.loads_base64(environ[secret.HOROVOD_SECRET_KEY])
    global_rank = int(environ[rank_env])

    verbose = settings.verbose
    driver = driver_service.SparkDriverClient(
        driver_addresses, secret_key, verbose=verbose)

    # Resolve which task this rank belongs to and connect to its service.
    index = driver.task_index_by_rank(global_rank)
    addresses = driver.all_task_addresses(index)
    task = task_service.SparkTaskClient(
        index, addresses, secret_key, verbose=verbose)
    resources = task.resources()
    task_info.set_resources(resources)

    # Fetch the user code from the driver, run it, and publish the outcome.
    func, func_args, func_kwargs = driver.code()
    outcome = func(*func_args, **func_kwargs)
    task.register_code_result(outcome)
コード例 #3
0
ファイル: task_service.py プロジェクト: yunzhezyz/horovod
    def _handle(self, req, client_address):
        """Handle one incoming request and return its response.

        Recognised request types, checked in this order:

        * ``RunCommandRequest`` - start the command in a background thread
          (only if none has been started yet) and acknowledge.
        * ``NotifyInitialRegistrationCompleteRequest`` - set the
          registration-complete flag and acknowledge.
        * ``CommandTerminatedRequest`` - report whether a command thread
          exists and is no longer alive.
        * ``RegisterCodeResultRequest`` - store the reported result and
          acknowledge.

        Any other request is delegated to the parent class.

        :param req: the request object to handle
        :param client_address: passed through to the parent handler
        :return: the response object for ``req``
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    if self._command_thread is None:
                        # Combine req.env with a copy of _command_env so the
                        # executed command gets both sets of variables.
                        if self._command_env:
                            merged_env = self._command_env.copy()
                            self._add_envs(merged_env, req.env)
                            req.env = merged_env

                        if self._verbose >= 2:
                            print("Task service executes command: {}".format(
                                req.command))
                            for name, val in req.env.items():
                                # mask anything that looks like a secret
                                shown = ('*' * len(val)
                                         if 'SECRET' in name else val)
                                print("Task service env: {} = {}".format(
                                    name, shown))

                        # Exactly one command is ever started; later
                        # RunCommandRequests fall through to the ack below.
                        self._command_thread = in_thread(
                            target=safe_shell_exec.execute,
                            args=(req.command, req.env))
                finally:
                    # wake every waiter before the lock is released
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                try:
                    self._initial_registration_complete = True
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandTerminatedRequest):
            with self._wait_cond:
                # terminated == a thread was started and has since finished
                finished = not (self._command_thread is None
                                or self._command_thread.is_alive())
            return CommandTerminatedResponse(finished)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)
コード例 #4
0
ファイル: task_service.py プロジェクト: zpcalan/horovod
    def _handle(self, req, client_address):
        """Handle one incoming request and return its response.

        Recognised request types, checked in this order:

        * ``RunCommandRequest`` - start the command in a background thread
          (only if none has been started yet) and acknowledge.
        * ``AbortCommandRequest`` - set the abort event of a started
          command and acknowledge.
        * ``NotifyInitialRegistrationCompleteRequest`` - set the
          registration-complete flag and acknowledge.
        * ``CommandExitCodeRequest`` - report whether the command thread has
          finished and, if so, the stored exit code.
        * ``WaitForCommandExitCodeRequest`` - block until a command thread
          exists and has finished, then return the stored exit code.
        * ``RegisterCodeResultRequest`` - store the reported result and
          acknowledge.

        Any other request is delegated to the parent class.

        :param req: the request object to handle
        :param client_address: passed through to the parent handler
        :return: the response object for ``req``
        """
        if isinstance(req, RunCommandRequest):
            with self._wait_cond:
                try:
                    if self._command_thread is None:
                        # Combine req.env with a copy of _command_env so the
                        # executed command gets both sets of variables.
                        if self._command_env:
                            merged_env = self._command_env.copy()
                            self._add_envs(merged_env, req.env)
                            req.env = merged_env

                        if self._verbose >= 2:
                            print("Task service executes command: {}".format(
                                req.command))
                            if self._verbose >= 3:
                                for name, val in req.env.items():
                                    # mask anything that looks like a secret
                                    shown = ('*' * len(val)
                                             if 'SECRET' in name else val)
                                    print("Task service env: {} = {}".format(
                                        name, shown))

                        # Exactly one command is ever started; later
                        # RunCommandRequests fall through to the ack below.
                        abort_event = threading.Event()
                        self._command_abort = abort_event
                        self._command_thread = in_thread(
                            target=self._run_command,
                            args=(req.command, req.env, abort_event))
                finally:
                    # wake every waiter before the lock is released
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, AbortCommandRequest):
            with self._wait_cond:
                # nothing to abort unless a command has been started
                if self._command_thread is not None:
                    self._command_abort.set()
            return network.AckResponse()

        if isinstance(req, NotifyInitialRegistrationCompleteRequest):
            with self._wait_cond:
                try:
                    self._initial_registration_complete = True
                finally:
                    self._wait_cond.notify_all()
            return network.AckResponse()

        if isinstance(req, CommandExitCodeRequest):
            with self._wait_cond:
                # terminated == a thread was started and has since finished
                finished = not (self._command_thread is None
                                or self._command_thread.is_alive())
                exit_code = self._command_exit_code if finished else None
                return CommandExitCodeResponse(finished, exit_code)

        if isinstance(req, WaitForCommandExitCodeRequest):
            with self._wait_cond:
                # Keep waiting on the condition while no command has been
                # started yet or the started command is still running; each
                # wait is bounded below by WAIT_FOR_COMMAND_MIN_DELAY.
                while (self._command_thread is None
                       or self._command_thread.is_alive()):
                    timeout = max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY)
                    self._wait_cond.wait(timeout)
                return WaitForCommandExitCodeResponse(self._command_exit_code)

        if isinstance(req, RegisterCodeResultRequest):
            self._fn_result = req.result
            return network.AckResponse()

        return super(BasicTaskService, self)._handle(req, client_address)