def _get_internal_ip_to_hostname_mapping(tpu_name, zone, num_vm): """Gets TPU VM internal IP to hostname mapping. Currently TPU CLH does not expose any TPU host machine name. SSH to each worker and get that instead. Returns: A map of TPU VM internal IP to TPU VM hostname. """ ip_to_host_name = {} def add_tpuvm_ip_to_hostname_mapping(worker_index): proc = subprocess.Popen([ 'gcloud', 'alpha', 'compute', 'tpus', 'tpu-vm', 'ssh', '--internal-ip', tpu_name, '--zone', zone, '--worker', str(worker_index), '--command', 'hostname; hostname -i' ], stdout=subprocess.PIPE) hostname = proc.stdout.readline().decode('utf-8').rstrip('\n') ip = proc.stdout.readline().decode('utf-8').rstrip('\n') ip_to_host_name[ip] = hostname xu.parallel_work(num_vm, add_tpuvm_ip_to_hostname_mapping, list(range(num_vm))) return ip_to_host_name
def _start_run(self, script_map):
    """Runs each client worker's remote script in parallel.

    Before launching, a daemon thread is started that calls
    ``self._check_client_mesh_health`` every ``self.HEARTBEAT_CHECK_PERIOD``
    seconds for as long as the process lives. The scripts are then launched
    through ``xu.parallel_work``. Each value of ``script_map`` is passed as the
    script-path dict and the matching key as the client worker.

    When both ``self.restart_server`` and ``self.tpuvm_mode`` are set, any
    process matching ``python -m <XRT_RUN_SERVER_PROCESS> <digits>`` on the
    worker is killed with ``pkill`` before the script runs.

    Args:
      script_map: Mapping of client worker -> dict whose 'remote_path' entry
        is the script to run on that worker.

    Raises:
      RuntimeError: From a worker's task when its remote script exits with a
        non-zero code.
    """

    def _launch(paths, worker):
        remote_script = paths['remote_path']
        if self.restart_server and self.tpuvm_mode:
            # Restart requested: kill any matching server process first; the
            # result of pkill is deliberately ignored.
            pkill_cmd = ('pkill -f "^python -m {} [0-9]+$"').format(
                self.XRT_RUN_SERVER_PROCESS)
            self._build_and_run_ssh(pkill_cmd, worker, log=False)
        code = self._build_and_run_ssh([remote_script], worker)
        if code == 0:
            return
        raise RuntimeError('Remote command exitted with code: {}'.format(code))

    def _monitor_mesh_health():
        # Heartbeat timeouts can be overridden through the environment.
        uneven_timeout = xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int, 900)
        even_timeout = xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, 1800)
        while True:
            self._check_client_mesh_health(uneven_timeout, even_timeout)
            time.sleep(self.HEARTBEAT_CHECK_PERIOD)

    monitor = threading.Thread(target=_monitor_mesh_health, daemon=True)
    monitor.start()
    workers = script_map.keys()
    script_paths = script_map.values()
    xu.parallel_work(len(script_map), _launch, script_paths, workers)
def wait_for_healthy_service(self):
    """Blocks until every TPU reported as under maintenance is healthy again.

    TPUs whose health is 'UNHEALTHY_MAINTENANCE' (per
    ``self.list_tpus_with_health``) are each waited on through
    ``xu.parallel_work``. Returns immediately when no TPU is in that state.
    """
    maintenance_tpus = self.list_tpus_with_health('UNHEALTHY_MAINTENANCE')
    if not maintenance_tpus:
        return

    def _block_until_healthy(name):
        cloud_tpu_client.Client(tpu=name).wait_for_healthy()

    xu.parallel_work(len(maintenance_tpus), _block_until_healthy,
                     maintenance_tpus)
def get_tpu_workers(self, as_client_worker=False):
    """Gets TPU VM cluster info.

    Calls the TPU CLH to get TPU node data and returns list of TPU worker VMs
    internal IP addresses. If zone and project are not specified at
    ClusterResolver init time, we infer these bits from GCE metadata.

    Args:
      as_client_worker: If True, return ClientWorker objects (this SSHes into
        every VM to look up its hostname); otherwise return ServiceWorker
        objects.

    Returns:
      A list of ServiceWorker or a list of ClientWorker. Workers are grouped
      by TPU in the order of ``self._tpus``; within a TPU they follow the
      order of its network endpoints. Empty if there are no TPUs.

    Raises:
      RuntimeError: If a TPU is not in READY state or not in HEALTHY state,
        or if no hostname was found for a client worker's internal IP.
    """
    tpus = list(self._tpus)
    if not tpus:
        return []

    # One slot per TPU, written by index, so the returned order does not
    # depend on which parallel lookup happens to finish first.
    workers_per_tpu = [[] for _ in tpus]

    def add_tpu_worker(index, tpu_name):
        ctc = cloud_tpu_client.Client(tpu=tpu_name)
        tpu_name = ctc.name()
        if ctc.state() != 'READY':
            raise RuntimeError(('TPU {tpu_name} is not READY yet. '
                                'Re-run when all TPUs are READY').format(
                                    tpu_name=tpu_name))
        if ctc.health() != 'HEALTHY':
            raise RuntimeError(('TPU {tpu_name} is not HEALTHY yet. '
                                'Re-run when all TPUs are HEALTHY').format(
                                    tpu_name=tpu_name))
        runtime_version = ctc.runtime_version()
        machine_type = ctc.accelerator_type()
        zone = ClusterResolver._parse_resource_url(ctc._full_name(),
                                                   'locations')
        network_endpoints = ctc.network_endpoints()

        ip_to_host_name = {}
        if as_client_worker:
            # Hostnames are not available from the TPU API, so they are
            # resolved by SSH-ing into each VM.
            ip_to_host_name = ClusterResolver._get_internal_ip_to_hostname_mapping(
                tpu_name, zone, len(network_endpoints))

        tpu_workers = []
        for endpoint in network_endpoints:
            if as_client_worker:
                internal_ip = endpoint['ipAddress']
                hostname = ip_to_host_name.get(internal_ip)
                if hostname is None:
                    raise RuntimeError(
                        'No hostname found for TPU {} worker with internal IP '
                        '{}'.format(tpu_name, internal_ip))
                worker = ClientWorker(
                    internal_ip=internal_ip,
                    machine_type=machine_type,
                    zone=zone,
                    hostname=hostname)
            else:
                worker = ServiceWorker(
                    internal_ip=endpoint['ipAddress'],
                    port=endpoint['port'],
                    machine_type=machine_type,
                    zone=zone,
                    runtime_version=runtime_version,
                    tpu=tpu_name)
            tpu_workers.append(worker)
        workers_per_tpu[index] = tpu_workers

    xu.parallel_work(len(tpus), add_tpu_worker, range(len(tpus)), tpus)

    return [worker for tpu_workers in workers_per_tpu for worker in tpu_workers]
def list_tpus_with_health(self, health):
    """Returns the TPUs backing the service workers whose health matches.

    Args:
      health: Health string compared against ``Client.health()`` (for example
        'UNHEALTHY_MAINTENANCE').

    Returns:
      A list of TPU names whose reported health equals ``health``. Empty when
      there are no service workers.
    """
    # Several service workers may share one TPU; query each TPU only once.
    tpus = {service_worker._tpu for service_worker in self._service_workers}
    if not tpus:
        # Nothing to query. This also avoids asking parallel_work for zero
        # workers, which a thread-pool-backed implementation may reject.
        return []

    def _tpu_with_health(tpu_name):
        ctc = cloud_tpu_client.Client(tpu=tpu_name)
        return tpu_name if ctc.health() == health else None

    results = xu.parallel_work(len(tpus), _tpu_with_health, tpus)
    return [res for res in results if res]
def wait_for_healthy_client(self, dist_executor, timeout=1200, interval=10):
    """Blocks until every client worker answers an SSH heartbeat.

    Each client worker is probed through ``xu.parallel_work``. A probe runs
    ``dist_executor._build_and_run_ssh`` in a daemon child process. It counts
    as failed if the process is still alive after ``interval`` seconds (it is
    then terminated) or if the process exit code is non-zero.

    Args:
      dist_executor: Object providing ``_build_and_run_ssh(cmd, client_worker)``.
      timeout: Seconds to keep retrying one client worker before giving up.
      interval: Seconds allowed per probe. Also the minimum spacing between the
        starts of consecutive probes of the same worker.

    Raises:
      RuntimeError: If a client worker is still unhealthy and another paced
        probe could not finish before ``timeout`` expires.
    """
    client_workers = self._client_workers
    if not client_workers:
        return

    def wait_for_healthy_client_worker(client_worker):
        heartbeat_check = [
            'echo', 'client_worker', '$(hostname)', 'is', 'healthy'
        ]
        check_timeout = time.time() + timeout

        def _healthy_client_worker():
            # NOTE(review): Process.exitcode is 0 whenever the target returns
            # normally. Any exit code returned by _build_and_run_ssh is
            # dropped, so a remote command that fails without raising still
            # counts as healthy here. Confirm whether that is intended.
            proc = multiprocessing.Process(
                target=dist_executor._build_and_run_ssh,
                args=(heartbeat_check, client_worker))
            proc.daemon = True
            proc.start()
            proc.join(interval)
            if proc.is_alive():
                # The probe hung; kill it and treat the worker as unhealthy.
                proc.terminate()
                return False
            return proc.exitcode == 0

        while True:
            attempt_start = time.time()
            if _healthy_client_worker():
                break
            logging.warning(
                'Waiting for client_worker "{}" to become healthy'.format(
                    client_worker))
            # Pace retries. Without this, a probe that fails instantly would
            # respawn SSH processes in a tight loop until the timeout.
            pause = max(0.0, interval - (time.time() - attempt_start))
            if time.time() + pause + interval > check_timeout:
                raise RuntimeError(
                    'Timed out waiting for client_worker {} to become healthy'
                    .format(client_worker))
            time.sleep(pause)

        logging.warning(
            'client_worker "{}" is healthy.'.format(client_worker))

    xu.parallel_work(len(client_workers), wait_for_healthy_client_worker,
                     client_workers)
def _start_run(self, script_map):
    """Runs each client worker's remote script in parallel.

    Before launching, a daemon thread is started that calls
    ``self._check_client_mesh_health`` every ``self.HEARTBEAT_CHECK_PERIOD``
    seconds for as long as the process lives. The scripts are then launched
    through ``xu.parallel_work``. Each value of ``script_map`` is passed as the
    script-path dict and the matching key as the client worker.

    Args:
      script_map: Mapping of client worker -> dict whose 'remote_path' entry
        is the script to run on that worker.

    Raises:
      RuntimeError: From a worker's task when its remote script exits with a
        non-zero code.
    """

    def _launch(paths, worker):
        code = self._build_and_run_ssh([paths['remote_path']], worker)
        if code == 0:
            return
        raise RuntimeError('Remote command exitted with code: {}'.format(code))

    def _monitor_mesh_health():
        # Heartbeat timeouts can be overridden through the environment.
        uneven_timeout = xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int, 900)
        even_timeout = xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, 1800)
        while True:
            self._check_client_mesh_health(uneven_timeout, even_timeout)
            time.sleep(self.HEARTBEAT_CHECK_PERIOD)

    monitor = threading.Thread(target=_monitor_mesh_health, daemon=True)
    monitor.start()
    workers = script_map.keys()
    script_paths = script_map.values()
    xu.parallel_work(len(script_map), _launch, script_paths, workers)