def task_exec(driver_addresses, settings, rank_env, local_rank_env):
    """Run the driver-provided function inside this task and report its result.

    Steps, in order:
      1. Start ``_parent_process_monitor`` on the parent pid in a background
         thread so this process dies if its parent terminates.
      2. Load the secret key from the ``secret.HOROVOD_SECRET_KEY`` environment
         variable via ``codec.loads_base64``, and parse the rank / local rank
         from the environment variables named by ``rank_env`` /
         ``local_rank_env``.
      3. Send (``HOROVOD_HOSTNAME``, local rank, rank) to the Spark driver; the
         value it returns is used as this task's index.
      4. Connect to the task service at that index and install its resources
         via ``task_info.set_resources``.
      5. Fetch ``(fn, args, kwargs)`` from the driver, call ``fn`` and register
         its return value with the task service. An exception raised by ``fn``
         propagates and no result is registered.

    :param driver_addresses: addresses used to reach the Spark driver service.
    :param settings: run settings; only ``settings.verbose`` is read here.
    :param rank_env: name of the env var holding this process's rank.
    :param local_rank_env: name of the env var holding this process's local rank.
    """
    # Die if the parent process terminates.
    in_thread(target=_parent_process_monitor, args=(os.getppid(), ))

    secret_key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    my_rank = int(os.environ[rank_env])
    my_local_rank = int(os.environ[local_rank_env])

    driver = driver_service.SparkDriverClient(
        driver_addresses, secret_key, verbose=settings.verbose)

    # Tell the driver our local rank and rank. In elastic mode the driver
    # already knows this mapping; static and elastic mode share this path
    # for simplicity.
    index = driver.set_local_rank_to_rank(
        os.environ['HOROVOD_HOSTNAME'], my_local_rank, my_rank)

    # Gather the available resources from the task service.
    addresses = driver.all_task_addresses(index)
    task = task_service.SparkTaskClient(
        index, addresses, secret_key, verbose=settings.verbose)
    task_info.set_resources(task.resources())

    # Run the user function shipped by the driver and hand back its result.
    fn, args, kwargs = driver.code()
    outcome = fn(*args, **kwargs)
    task.register_code_result(outcome)
def task_exec(driver_addresses, settings, rank_env):
    """Run the driver-provided function inside this task and report its result.

    Steps, in order:
      1. Start ``_parent_process_monitor`` on the parent pid in a background
         thread so this process dies if its parent terminates.
      2. Load the secret key from the ``secret.HOROVOD_SECRET_KEY`` environment
         variable via ``codec.loads_base64`` and parse the rank from the
         environment variable named by ``rank_env``.
      3. Ask the Spark driver for this task's index given the rank.
      4. Connect to the task service at that index and install its resources
         via ``task_info.set_resources``.
      5. Fetch ``(fn, args, kwargs)`` from the driver, call ``fn`` and register
         its return value with the task service. An exception raised by ``fn``
         propagates and no result is registered.

    :param driver_addresses: addresses used to reach the Spark driver service.
    :param settings: run settings; only ``settings.verbose`` is read here.
    :param rank_env: name of the env var holding this process's rank.
    """
    # Die if the parent process terminates.
    in_thread(target=_parent_process_monitor, args=(os.getppid(), ))

    secret_key = codec.loads_base64(os.environ[secret.HOROVOD_SECRET_KEY])
    my_rank = int(os.environ[rank_env])

    driver = driver_service.SparkDriverClient(
        driver_addresses, secret_key, verbose=settings.verbose)
    index = driver.task_index_by_rank(my_rank)

    # Connect to our task service and install the resources it reports.
    addresses = driver.all_task_addresses(index)
    task = task_service.SparkTaskClient(
        index, addresses, secret_key, verbose=settings.verbose)
    task_info.set_resources(task.resources())

    # Run the user function shipped by the driver and hand back its result.
    fn, args, kwargs = driver.code()
    outcome = fn(*args, **kwargs)
    task.register_code_result(outcome)
def _handle(self, req, client_address):
    """Dispatch a single request received by the task service.

    Handled here:
      * RunCommandRequest: start ``safe_shell_exec.execute`` for the command in
        a background thread. Only the first such request starts a command;
        later ones are acknowledged and ignored.
      * NotifyInitialRegistrationCompleteRequest: set
        ``_initial_registration_complete``.
      * CommandTerminatedRequest: report whether a command was started and its
        thread is no longer alive.
      * RegisterCodeResultRequest: store ``req.result`` in ``_fn_result``.

    Any other request is delegated to ``super(BasicTaskService, self)._handle``.
    """
    if isinstance(req, RunCommandRequest):
        self._wait_cond.acquire()
        try:
            # Exactly one command may ever run, which makes this request
            # idempotent.
            if self._command_thread is None:
                # Add req.env to a copy of _command_env so the command sees both.
                if self._command_env:
                    merged = self._command_env.copy()
                    self._add_envs(merged, req.env)
                    req.env = merged

                if self._verbose >= 2:
                    print("Task service executes command: {}".format(
                        req.command))
                    for name, raw in req.env.items():
                        # Mask values of env vars whose name contains 'SECRET'.
                        shown = '*' * len(raw) if 'SECRET' in name else raw
                        print("Task service env: {} = {}".format(name, shown))

                self._command_thread = in_thread(
                    target=safe_shell_exec.execute,
                    args=(req.command, req.env))
        finally:
            # Wake anyone waiting on the condition, even if setup raised.
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

    if isinstance(req, NotifyInitialRegistrationCompleteRequest):
        self._wait_cond.acquire()
        try:
            self._initial_registration_complete = True
        finally:
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

    if isinstance(req, CommandTerminatedRequest):
        self._wait_cond.acquire()
        try:
            thread = self._command_thread
            finished = thread is not None and not thread.is_alive()
        finally:
            self._wait_cond.release()
        return CommandTerminatedResponse(finished)

    if isinstance(req, RegisterCodeResultRequest):
        self._fn_result = req.result
        return network.AckResponse()

    return super(BasicTaskService, self)._handle(req, client_address)
def _handle(self, req, client_address):
    """Dispatch a single request received by the task service.

    Handled here:
      * RunCommandRequest: run ``self._run_command`` in a background thread,
        passing it a fresh ``threading.Event`` as its abort signal. Only the
        first such request starts a command; later ones are acknowledged and
        ignored.
      * AbortCommandRequest: set the abort event if a command was started.
      * NotifyInitialRegistrationCompleteRequest: set
        ``_initial_registration_complete``.
      * CommandExitCodeRequest: report whether the command thread has finished
        and, if so, ``_command_exit_code`` (otherwise ``None``).
      * WaitForCommandExitCodeRequest: block until a command thread exists and
        is no longer alive, then report ``_command_exit_code``.
      * RegisterCodeResultRequest: store ``req.result`` in ``_fn_result``.

    Any other request is delegated to ``super(BasicTaskService, self)._handle``.
    """
    if isinstance(req, RunCommandRequest):
        self._wait_cond.acquire()
        try:
            # Exactly one command may ever run, which makes this request
            # idempotent.
            if self._command_thread is None:
                # Add req.env to a copy of _command_env so the command sees both.
                if self._command_env:
                    merged = self._command_env.copy()
                    self._add_envs(merged, req.env)
                    req.env = merged

                if self._verbose >= 2:
                    print("Task service executes command: {}".format(
                        req.command))
                if self._verbose >= 3:
                    for name, raw in req.env.items():
                        # Mask values of env vars whose name contains 'SECRET'.
                        shown = '*' * len(raw) if 'SECRET' in name else raw
                        print("Task service env: {} = {}".format(name, shown))

                abort = threading.Event()
                self._command_abort = abort
                self._command_thread = in_thread(
                    target=self._run_command,
                    args=(req.command, req.env, abort))
        finally:
            # Wake anyone waiting on the condition, even if setup raised.
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

    if isinstance(req, AbortCommandRequest):
        self._wait_cond.acquire()
        try:
            if self._command_thread is not None:
                self._command_abort.set()
        finally:
            self._wait_cond.release()
        return network.AckResponse()

    if isinstance(req, NotifyInitialRegistrationCompleteRequest):
        self._wait_cond.acquire()
        try:
            self._initial_registration_complete = True
        finally:
            self._wait_cond.notify_all()
            self._wait_cond.release()
        return network.AckResponse()

    if isinstance(req, CommandExitCodeRequest):
        self._wait_cond.acquire()
        try:
            thread = self._command_thread
            finished = thread is not None and not thread.is_alive()
            exit_code = self._command_exit_code if finished else None
            return CommandExitCodeResponse(finished, exit_code)
        finally:
            self._wait_cond.release()

    if isinstance(req, WaitForCommandExitCodeRequest):
        self._wait_cond.acquire()
        try:
            # Re-check after every wait; the wait timeout is recomputed on each
            # iteration, exactly as often as the loop body runs.
            while (self._command_thread is None
                   or self._command_thread.is_alive()):
                self._wait_cond.wait(
                    max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
            return WaitForCommandExitCodeResponse(self._command_exit_code)
        finally:
            self._wait_cond.release()

    if isinstance(req, RegisterCodeResultRequest):
        self._fn_result = req.result
        return network.AckResponse()

    return super(BasicTaskService, self)._handle(req, client_address)