def add_service_worker(tpu_name):
    """Validate a Cloud TPU and append one ServiceWorker per network endpoint.

    Appends to ``workers``, which is not defined in this function; it is read
    from an enclosing scope (presumably a list — confirm against the caller).

    Args:
      tpu_name: TPU identifier passed to ``cloud_tpu_client.Client``. It is
        rebound to the name reported by the client (``ctc.name()``), and that
        reported name is what error messages and workers use.

    Raises:
      RuntimeError: If the TPU state is not 'READY' or its health is not
        'HEALTHY'.
    """
    ctc = cloud_tpu_client.Client(tpu=tpu_name)
    tpu_name = ctc.name()
    if ctc.state() != 'READY':
        raise RuntimeError(('TPU {tpu_name} is not READY yet. '
                            'Re-run when all TPUs are READY').format(
                                tpu_name=tpu_name))
    if ctc.health() != 'HEALTHY':
        raise RuntimeError(('TPU {tpu_name} is not HEALTHY yet. '
                            'Re-run when all TPUs are HEALTHY').format(
                                tpu_name=tpu_name))
    runtime_version = ctc.runtime_version()
    machine_type = ctc.accelerator_type()
    # Call the URL-parsing helper through the class, exactly as add_tpu_worker
    # does, so this function does not depend on a `self` it never receives.
    zone = ClusterResolver._parse_resource_url(ctc._full_name(), 'locations')
    network_endpoints = ctc.network_endpoints()
    for endpoint in network_endpoints:
        worker = ServiceWorker(internal_ip=endpoint['ipAddress'],
                               port=endpoint['port'],
                               machine_type=machine_type,
                               zone=zone,
                               runtime_version=runtime_version,
                               tpu=tpu_name)
        workers.append(worker)
def main(_):
    """Poll the health of the Cloud TPU bound to a Kubernetes container.

    The TPU name is read from the pod annotation
    ``name.cloud-tpus.google.com/<FLAGS.container>``. The loop then checks TPU
    health every ``FLAGS.interval`` seconds. It stops when the TPU reports it
    is no longer recoverable, or when the watched container has terminated.

    Args:
      _: Unused (presumably the argv list supplied by the app runner).
    """
    import sys  # `sys.exit` is always available, unlike the site-provided `exit`.

    logging.info('Starting TPU health monitor for container %s...',
                 FLAGS.container)
    try:
        config.load_incluster_config()
    except config.ConfigException:
        # Not running inside a cluster: fall back to the local kubeconfig.
        config.load_kube_config()
    k8s_client = client.CoreV1Api()

    pod = k8s_client.read_namespaced_pod(FLAGS.pod, FLAGS.namespace)
    tpu_name_annotation = 'name.cloud-tpus.google.com/{}'.format(
        FLAGS.container)
    # Only the final path component of the annotation value is used.
    tpu_name = os.path.basename(pod.metadata.annotations[tpu_name_annotation])
    tpu_client = cloud_tpu_client.Client(tpu_name, FLAGS.zone, FLAGS.project)
    logging.info('TPU health monitor initialized for %s.', tpu_name)

    if FLAGS.verbose:
        logging.set_verbosity(logging.DEBUG)
    else:
        logging.set_verbosity(logging.WARNING)

    while True:
        try:
            health = tpu_client.health()
        except ValueError as e:
            logging.error('Error getting TPU status: %s', str(e))
            health = None

        if health == 'HEALTHY':
            logging.info('TPU health: %s', health)
        else:
            logging.warning('TPU health: %s', health)

        if not tpu_client.recoverable():
            logging.warning('TPU entered un-recoverable state: %s',
                            tpu_client.state())
            break

        pod = k8s_client.read_namespaced_pod_status(FLAGS.pod, FLAGS.namespace)
        # `container_statuses` is optional in the pod status and can be None.
        # Iterating None would raise TypeError instead of reaching the
        # "not found" handling below, so treat it as an empty list.
        container_statuses = pod.status.container_statuses or []
        status = next(
            (c for c in container_statuses if c.name == FLAGS.container), None)
        if status is None:
            logging.fatal(
                'Status for container `%s` not found in statuses:\n%s',
                FLAGS.container, str(pod.status))
            sys.exit(1)

        if status.state.terminated:
            logging.warning('Container `%s` terminated with status:\n%s',
                            FLAGS.container, str(status))
            break

        time.sleep(FLAGS.interval)
def update_tpu_runtime(tpu_name, version):
    """Reconfigure a Cloud TPU to the runtime version ``version.tpu``.

    Installs the ``cloud-tpu-client`` package on the fly if it is not
    importable.

    Args:
      tpu_name: Name of the TPU passed to ``cloud_tpu_client.Client``.
      version: Object whose ``tpu`` attribute is the runtime version string
        handed to ``configure_tpu_version``.

    Raises:
      subprocess.CalledProcessError: If installing ``cloud-tpu-client`` fails.
    """
    import importlib
    import sys

    print(f'Updating TPU runtime to {version.tpu} ...')
    try:
        import cloud_tpu_client
    except ImportError:
        # Install into the *running* interpreter; a bare `pip` on PATH may
        # belong to a different Python. check_call surfaces install failures
        # instead of deferring them to a confusing ImportError below.
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
        # Let the import system notice a package created after start-up.
        importlib.invalidate_caches()
        import cloud_tpu_client

    tpu_client = cloud_tpu_client.Client(tpu_name)
    tpu_client.configure_tpu_version(version.tpu)
    print('Done updating TPU runtime')
def _set_tpuvm_mode(self):
    """Set the TPU VM mode flags on ``self``.

    Always writes ``self._tpuvm_mode`` and
    ``self._tpuvm_mode_with_remote_coordinator``. Both flags start as False.
    - If the instance metadata ``instance/attributes/accelerator-type`` looks
      like ``v<digits>-<digits>``, only ``_tpuvm_mode`` becomes True.
    - Otherwise the first TPU in ``self._tpus`` is queried. If its
      ``apiVersion`` property is ``'V2_ALPHA1'``, both flags become True.
    """
    self._tpuvm_mode = False
    self._tpuvm_mode_with_remote_coordinator = False

    accelerator = ClusterResolver.get_instance_metadata(
        'instance/attributes/accelerator-type')
    # NOTE(review): assumes get_instance_metadata returns a str; re.match
    # would raise TypeError on None.
    if re.match(r'v[0-9]+-[0-9]+', accelerator) is not None:
        # Only a VM with a TPU attached carries the accelerator-type metadata.
        self._tpuvm_mode = True
    else:
        tpu_api_version = cloud_tpu_client.Client(
            tpu=self._tpus[0])._get_tpu_property('apiVersion')
        # Only the TPU VM API version should be V2_ALPHA1. This VM lacks the
        # accelerator-type metadata but the target TPU is a TPU VM, so assume
        # this VM is acting as a remote coordinator.
        if tpu_api_version == 'V2_ALPHA1':
            self._tpuvm_mode = self._tpuvm_mode_with_remote_coordinator = True
def add_tpu_worker(tpu_name):
    """Check that a TPU is ready and healthy, then register its workers.

    One worker is appended to ``workers`` for every network endpoint of the
    TPU:
    - a ``ClientWorker``, with the hostname looked up from the endpoint's
      internal IP, when ``as_client_worker`` is truthy;
    - otherwise a ``ServiceWorker``.
    ``workers`` and ``as_client_worker`` are not parameters; they come from an
    enclosing scope.

    Args:
      tpu_name: TPU identifier handed to ``cloud_tpu_client.Client``. It is
        rebound to the name the client reports.

    Raises:
      RuntimeError: If the TPU is not READY or not HEALTHY.
    """
    tpu = cloud_tpu_client.Client(tpu=tpu_name)
    tpu_name = tpu.name()

    if tpu.state() != 'READY':
        raise RuntimeError(('TPU {tpu_name} is not READY yet. '
                            'Re-run when all TPUs are READY').format(
                                tpu_name=tpu_name))
    if tpu.health() != 'HEALTHY':
        raise RuntimeError(('TPU {tpu_name} is not HEALTHY yet. '
                            'Re-run when all TPUs are HEALTHY').format(
                                tpu_name=tpu_name))

    runtime_version = tpu.runtime_version()
    machine_type = tpu.accelerator_type()
    zone = ClusterResolver._parse_resource_url(tpu._full_name(), 'locations')
    endpoints = tpu.network_endpoints()

    if as_client_worker:
        ip_to_hostname = ClusterResolver._get_internal_ip_to_hostname_mapping(
            tpu_name, zone, len(endpoints))
        for endpoint in endpoints:
            ip = endpoint['ipAddress']
            workers.append(
                ClientWorker(internal_ip=ip,
                             machine_type=machine_type,
                             zone=zone,
                             hostname=ip_to_hostname[ip]))
    else:
        for endpoint in endpoints:
            workers.append(
                ServiceWorker(internal_ip=endpoint['ipAddress'],
                              port=endpoint['port'],
                              machine_type=machine_type,
                              zone=zone,
                              runtime_version=runtime_version,
                              tpu=tpu_name))
def wait_for_healthy_service_worker(tpu_name):
    """Wait for the named TPU to become healthy.

    Delegates to ``cloud_tpu_client.Client.wait_for_healthy`` on a freshly
    constructed client.
    """
    cloud_tpu_client.Client(tpu=tpu_name).wait_for_healthy()
def _tpu_with_health(tpu_name):
    """Return ``tpu_name`` if the TPU's health equals ``health``, else None.

    ``health`` is not a parameter; it is read from an enclosing scope.
    """
    tpu_client = cloud_tpu_client.Client(tpu_name)
    return tpu_name if tpu_client.health() == health else None