def _handle(self, req, client_address):
    """Handle task-registration and task-address requests on the driver.

    Args:
        req: The incoming request; dispatched on its type.
        client_address: Address of the connecting client; element 0 (the
            source IP) is used to filter the task's declared addresses.

    Returns:
        ``network.AckResponse()`` for the two registration requests, an
        ``AllTaskAddressesResponse`` for ``AllTaskAddressesRequest``, and
        otherwise whatever the parent service's ``_handle`` returns.
    """
    if isinstance(req, RegisterTaskRequest):
        with self._wait_cond:
            try:
                assert 0 <= req.index < self._num_proc
                print("Setting address from RegisterTaskRequest: " + str(req.task_addresses))
                self._all_task_addresses[req.index] = req.task_addresses
                # Just use source address for service for fast probing.
                self._task_addresses_for_driver[req.index] = \
                    self._filter_by_ip(req.task_addresses, client_address[0])
                if not self._task_addresses_for_driver[req.index]:
                    # No match is possible if one of the servers is behind NAT.
                    # We don't throw exception here, but will allow the following
                    # code fail with NoValidAddressesFound.
                    print(
                        'ERROR: Task {index} declared addresses {task_addresses}, '
                        'but has connected from a different address {source}. '
                        'This is not supported. Is the server behind NAT?'
                        ''.format(index=req.index, task_addresses=req.task_addresses,
                                  source=client_address[0]))

                # Remove host hash earlier registered under this index.
                if req.index in self._task_index_host_hash:
                    earlier_host_hash = self._task_index_host_hash[req.index]
                    if earlier_host_hash != req.host_hash:
                        self._task_host_hash_indices[earlier_host_hash].remove(
                            req.index)

                # Make index -> host hash map.
                self._task_index_host_hash[req.index] = req.host_hash

                # Make host hash -> indices map.
                if req.host_hash not in self._task_host_hash_indices:
                    self._task_host_hash_indices[req.host_hash] = []
                host_indices = self._task_host_hash_indices[req.host_hash]
                # A task re-registering from the same host hash is still listed
                # (the removal above only runs when the host hash changed), so
                # an unconditional append would duplicate its index.
                if req.index not in host_indices:
                    host_indices.append(req.index)
                # TODO: this sorting is a problem in elastic horovod
                host_indices.sort()
            finally:
                # Wake any threads waiting on _wait_cond, even if handling
                # this request failed part way through.
                self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, RegisterTaskToTaskAddressesRequest):
        print("Task being registered to task address: " + str(req.task_addresses))
        self.register_task_to_task_addresses(req.index, req.task_addresses)
        return network.AckResponse()

    if isinstance(req, AllTaskAddressesRequest):
        return AllTaskAddressesResponse(self._all_task_addresses[req.index])

    return super(BasicDriverService, self)._handle(req, client_address)
def _handle(self, req, client_address):
    """Sleep for ``self._duration`` seconds on a ``SleepRequest``.

    Logs the current time and the client before sleeping, then acknowledges.
    Every other request type is passed to the parent service's ``_handle``.
    """
    if not isinstance(req, SleepRequest):
        return super(TestSleepService, self)._handle(req, client_address)

    print('{}: sleeping for client {}'.format(time.time(), client_address))
    time.sleep(self._duration)
    return network.AckResponse()
def _handle(self, req, client_address):
    """Forward a ``HostsUpdatedRequest`` timestamp to the manager.

    Calls ``self._manager.handle_hosts_updated(req.timestamp)`` and
    acknowledges. Every other request type goes to the parent ``_handle``.
    """
    if not isinstance(req, HostsUpdatedRequest):
        return super(WorkerNotificationService, self)._handle(req, client_address)

    self._manager.handle_hosts_updated(req.timestamp)
    return network.AckResponse()
def _handle(self, req, client_address):
    """Handle Spark-driver requests on top of the basic driver service.

    Handles these request types:
      * ``TaskHostHashIndicesRequest``: returns the task indices registered
        for a host hash.
      * ``SetLocalRankToRankRequest``: maps ``req.rank`` to the task index at
        position ``req.local_rank`` for ``req.host``. Any rank previously
        mapped to that index is dropped. This runs under ``self._lock``.
      * ``TaskIndexByRankRequest``: looks up the task index for a rank,
        also under ``self._lock``.
      * ``CodeRequest``: returns ``self._fn``, ``self._args`` and
        ``self._kwargs``.
      * ``WaitForTaskShutdownRequest``: waits on ``self._task_shutdown``.

    Anything else is delegated to the parent ``_handle``.
    """
    if isinstance(req, TaskHostHashIndicesRequest):
        return TaskHostHashIndicesResponse(
            self._task_host_hash_indices[req.host_hash])

    if isinstance(req, SetLocalRankToRankRequest):
        with self._lock:
            # get index for host and local_rank
            indices = self._task_host_hash_indices[req.host]
            index = indices[req.local_rank]

            # Remove the first rank previously mapped to this index, so the
            # index ends up owned by req.rank alone. Iterate over a snapshot
            # of the items so deleting from the dict is safe.
            for prev_rank, prev_index in list(self._ranks_to_indices.items()):
                if prev_index == index:
                    del self._ranks_to_indices[prev_rank]
                    break

            # memorize rank's index
            self._ranks_to_indices[req.rank] = index
        return SetLocalRankToRankResponse(index)

    if isinstance(req, TaskIndexByRankRequest):
        with self._lock:
            return TaskIndexByRankResponse(self._ranks_to_indices[req.rank])

    if isinstance(req, CodeRequest):
        return CodeResponse(self._fn, self._args, self._kwargs)

    if isinstance(req, WaitForTaskShutdownRequest):
        self._task_shutdown.wait()
        return network.AckResponse()

    return super(SparkDriverService, self)._handle(req, client_address)
def _handle(self, req, client_address):
    """Serve task-side command requests; delegate unknown ones upward.

    Handles these operations:
      * start the single command (only the first ``RunCommandRequest`` does
        anything);
      * abort the command;
      * mark the initial registration complete;
      * report the command's exit code, or wait for it;
      * store a code result.

    Every handler except ``RegisterCodeResultRequest`` runs while holding
    ``self._wait_cond``. ``_fn_result`` is assigned without taking it.
    """
    if isinstance(req, RunCommandRequest):
        with self._wait_cond:
            try:
                # We only permit executing exactly one command, so this is idempotent.
                if self._command_thread is None:
                    # we add req.env to _command_env and make this available to the executed command
                    if self._command_env:
                        merged = self._command_env.copy()
                        self._add_envs(merged, req.env)
                        req.env = merged

                    if self._verbose >= 2:
                        print("Task service executes command: {}".format(req.command))
                    if self._verbose >= 3:
                        for name, setting in req.env.items():
                            # Mask values of variables whose name contains SECRET.
                            shown = '*' * len(setting) if 'SECRET' in name else setting
                            print("Task service env: {} = {}".format(name, shown))

                    self._command_abort = threading.Event()
                    self._command_thread = in_thread(
                        target=self._run_command,
                        args=(req.command, req.env, self._command_abort))
            finally:
                self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, AbortCommandRequest):
        with self._wait_cond:
            if self._command_thread is not None:
                self._command_abort.set()
        return network.AckResponse()

    if isinstance(req, NotifyInitialRegistrationCompleteRequest):
        with self._wait_cond:
            try:
                self._initial_registration_complete = True
            finally:
                self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, CommandExitCodeRequest):
        with self._wait_cond:
            thread = self._command_thread
            terminated = thread is not None and not thread.is_alive()
            return CommandExitCodeResponse(
                terminated, self._command_exit_code if terminated else None)

    if isinstance(req, WaitForCommandExitCodeRequest):
        with self._wait_cond:
            # Keep waiting until a command thread exists and has finished.
            while not (self._command_thread is not None
                       and not self._command_thread.is_alive()):
                self._wait_cond.wait(max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
            return WaitForCommandExitCodeResponse(self._command_exit_code)

    if isinstance(req, RegisterCodeResultRequest):
        self._fn_result = req.result
        return network.AckResponse()

    return super(BasicTaskService, self)._handle(req, client_address)
def _handle(self, req, client_address):
    """Serve task-side command requests, including captured output streaming.

    Handles these operations:
      * start the single command, optionally capturing its stdout/stderr
        into ``Pipe`` objects;
      * stream a captured output stream;
      * abort the command and close its pipes;
      * mark the initial registration complete;
      * report the command's exit code, or wait for it;
      * store a code result.

    Anything else is delegated to the parent ``_handle``.
    """
    if isinstance(req, RunCommandRequest):
        with self._wait_cond:
            try:
                # We only permit executing exactly one command, so this is idempotent.
                if self._command_thread is None:
                    # we add req.env to _command_env and make this available to the executed command
                    if self._command_env:
                        merged = self._command_env.copy()
                        self._add_envs(merged, req.env)
                        req.env = merged

                    if self._verbose >= 2:
                        print("Task service executes command: {}".format(req.command))
                    if self._verbose >= 3:
                        for name, setting in req.env.items():
                            # Mask values of variables whose name contains SECRET.
                            shown = '*' * len(setting) if 'SECRET' in name else setting
                            print("Task service env: {} = {}".format(name, shown))

                    self._command_abort = threading.Event()
                    self._command_stdout = Pipe() if req.capture_stdout else None
                    self._command_stderr = Pipe() if req.capture_stderr else None
                    self._command_thread = in_thread(
                        self._run_command,
                        (req.command, req.env, self._command_abort,
                         self._command_stdout, self._command_stderr,
                         self._index, req.prefix_output_with_timestamp))
            finally:
                self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, StreamCommandOutputRequest):
        # Wait for command to start
        self.wait_for_command_start()

        # We only expect streaming each command output stream once concurrently
        if isinstance(req, StreamCommandStdOutRequest):
            return self.stream_output(self._command_stdout)
        if isinstance(req, StreamCommandStdErrRequest):
            return self.stream_output(self._command_stderr)
        return CommandOutputNotCaptured()

    if isinstance(req, AbortCommandRequest):
        with self._wait_cond:
            if self._command_thread is not None:
                self._command_abort.set()
            for pipe in (self._command_stdout, self._command_stderr):
                if pipe is not None:
                    pipe.close()
        return network.AckResponse()

    if isinstance(req, NotifyInitialRegistrationCompleteRequest):
        with self._wait_cond:
            try:
                self._initial_registration_complete = True
            finally:
                self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, CommandExitCodeRequest):
        with self._wait_cond:
            thread = self._command_thread
            terminated = thread is not None and not thread.is_alive()
            return CommandExitCodeResponse(
                terminated, self._command_exit_code if terminated else None)

    if isinstance(req, WaitForCommandExitCodeRequest):
        with self._wait_cond:
            # Keep waiting until a command thread exists and has finished.
            while not (self._command_thread is not None
                       and not self._command_thread.is_alive()):
                self._wait_cond.wait(max(req.delay, WAIT_FOR_COMMAND_MIN_DELAY))
            return WaitForCommandExitCodeResponse(self._command_exit_code)

    if isinstance(req, RegisterCodeResultRequest):
        self._fn_result = req.result
        return network.AckResponse()

    return super(BasicTaskService, self)._handle(req, client_address)
def _handle(self, req, client_address):
    """Handle dispatcher/worker registration, registration waits and shutdown.

    Validation failures and timeouts are *returned* as exception instances
    (``IndexError``, ``ValueError``, ``TimeoutException``) rather than raised.
    NOTE(review): presumably the network layer relays returned exceptions to
    the client; confirm.

    All dispatcher/worker state is read and written while holding
    ``self._wait_cond``. Unknown requests go to the parent ``_handle``.
    """

    def id_out_of_range(dispatcher_id):
        # Return the IndexError to send back when dispatcher_id lies outside
        # [0, self._max_dispatcher_id]; otherwise None.
        if 0 <= dispatcher_id <= self._max_dispatcher_id:
            return None
        return IndexError(
            f'Dispatcher id must be within [0..{self._max_dispatcher_id}]: '
            f'{dispatcher_id}')

    if isinstance(req, RegisterDispatcherRequest):
        with self._wait_cond:
            error = id_out_of_range(req.dispatcher_id)
            if error is not None:
                return error
            registered = self._dispatcher_addresses[req.dispatcher_id]
            if registered is not None and registered != req.dispatcher_address:
                return ValueError(
                    f'Dispatcher with id {req.dispatcher_id} has already been registered under '
                    f'different address {registered}: '
                    f'{req.dispatcher_address}')
            self._dispatcher_addresses[req.dispatcher_id] = req.dispatcher_address
            self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, WaitForDispatcherRegistrationRequest):
        with self._wait_cond:
            error = id_out_of_range(req.dispatcher_id)
            if error is not None:
                return error
            try:
                tmout = timeout.Timeout(
                    timeout=req.timeout,
                    message='Timed out waiting for {activity}. Try to find out what takes '
                            'the dispatcher so long to register or increase timeout.')
                while self._dispatcher_addresses[req.dispatcher_id] is None:
                    self._wait_cond.wait(tmout.remaining())
                    tmout.check_time_out_for(
                        f'dispatcher {req.dispatcher_id} to register')
            except TimeoutException as e:
                return e
            # Read the address while the lock is still held.
            address = self._dispatcher_addresses[req.dispatcher_id]
        return WaitForDispatcherRegistrationResponse(address)

    if isinstance(req, RegisterDispatcherWorkerRequest):
        with self._wait_cond:
            error = id_out_of_range(req.dispatcher_id)
            if error is not None:
                return error
            self._dispatcher_worker_ids[req.dispatcher_id].update({req.worker_id})
            self._wait_cond.notify_all()
        return network.AckResponse()

    if isinstance(req, WaitForDispatcherWorkerRegistrationRequest):
        # if there is only a single dispatcher, wait for that one instead of the requested one
        dispatcher_id = req.dispatcher_id if self._max_dispatcher_id > 0 else 0
        with self._wait_cond:
            # Validate the id we actually wait on. Validating req.dispatcher_id
            # would reject every non-zero id when only one dispatcher exists,
            # which would make the remap above dead code.
            error = id_out_of_range(dispatcher_id)
            if error is not None:
                return error
            try:
                tmout = timeout.Timeout(
                    timeout=req.timeout,
                    message='Timed out waiting for {activity}. Try to find out what takes '
                            'the workers so long to register or increase timeout.')
                while len(self._dispatcher_worker_ids[dispatcher_id]) < self._workers_per_dispatcher:
                    self._wait_cond.wait(tmout.remaining())
                    tmout.check_time_out_for(
                        f'workers for dispatcher {dispatcher_id} to register')
            except TimeoutException as e:
                return e
        return network.AckResponse()

    if isinstance(req, ShutdownRequest):
        in_thread(self.shutdown)
        return network.AckResponse()

    if isinstance(req, WaitForShutdownRequest):
        with self._wait_cond:
            while not self._shutdown:
                self._wait_cond.wait()
        return network.AckResponse()

    return super()._handle(req, client_address)