Example #1
0
        def add_service_worker(tpu_name):
            """Append one ServiceWorker per network endpoint of `tpu_name`.

            Raises RuntimeError unless the TPU is both READY and HEALTHY.
            Mutates the enclosing `workers` list.
            """
            tpu_client = cloud_tpu_client.Client(tpu=tpu_name)
            tpu_name = tpu_client.name()
            # Fail fast: every TPU must be fully provisioned before we wire
            # workers to it.
            if tpu_client.state() != 'READY':
                raise RuntimeError(
                    'TPU {tpu_name} is not READY yet. '
                    'Re-run when all TPUs are READY'.format(tpu_name=tpu_name))
            if tpu_client.health() != 'HEALTHY':
                raise RuntimeError(
                    'TPU {tpu_name} is not HEALTHY yet. '
                    'Re-run when all TPUs are HEALTHY'.format(tpu_name=tpu_name))

            runtime_version = tpu_client.runtime_version()
            machine_type = tpu_client.accelerator_type()
            zone = self._parse_resource_url(tpu_client._full_name(), 'locations')

            workers.extend(
                ServiceWorker(internal_ip=endpoint['ipAddress'],
                              port=endpoint['port'],
                              machine_type=machine_type,
                              zone=zone,
                              runtime_version=runtime_version,
                              tpu=tpu_name)
                for endpoint in tpu_client.network_endpoints())
Example #2
0
def main(_):
    """Monitor TPU and container health, exiting when either terminates.

    Polls the cloud TPU health endpoint and the pod's container status every
    FLAGS.interval seconds. Breaks out of the loop (ending the program) when
    the TPU enters an unrecoverable state or the monitored container
    terminates.
    """
    # Configure verbosity FIRST so the startup messages below honor
    # --verbose; previously this ran after several log calls, silently
    # ignoring the flag for them.
    if FLAGS.verbose:
        logging.set_verbosity(logging.DEBUG)
    else:
        logging.set_verbosity(logging.WARNING)

    logging.info('Starting TPU health monitor for container %s...',
                 FLAGS.container)

    # Prefer in-cluster credentials; fall back to local kubeconfig when
    # running outside the cluster (e.g. during development).
    try:
        config.load_incluster_config()
    except config.ConfigException:
        config.load_kube_config()

    k8s_client = client.CoreV1Api()
    pod = k8s_client.read_namespaced_pod(FLAGS.pod, FLAGS.namespace)

    # The TPU name is stored as a pod annotation keyed by container name;
    # the annotation value is a resource path, so keep only the last segment.
    tpu_name_annotation = 'name.cloud-tpus.google.com/{}'.format(
        FLAGS.container)
    tpu_name = os.path.basename(pod.metadata.annotations[tpu_name_annotation])

    tpu_client = cloud_tpu_client.Client(tpu_name, FLAGS.zone, FLAGS.project)

    logging.info('TPU health monitor initialized for %s.', tpu_name)

    while True:
        try:
            health = tpu_client.health()
        except ValueError as e:
            # Treat a failed status fetch as unknown health rather than
            # crashing the monitor; we'll retry on the next iteration.
            logging.error('Error getting TPU status: %s', str(e))
            health = None

        if health == 'HEALTHY':
            logging.info('TPU health: %s', health)
        else:
            logging.warning('TPU health: %s', health)

            if not tpu_client.recoverable():
                logging.warning('TPU entered un-recoverable state: %s',
                                tpu_client.state())
                break

        # Also stop monitoring once the watched container has terminated.
        pod = k8s_client.read_namespaced_pod_status(FLAGS.pod, FLAGS.namespace)
        try:
            status = next(c for c in pod.status.container_statuses
                          if c.name == FLAGS.container)
        except StopIteration:
            logging.fatal(
                'Status for container `%s` not found in statuses:\n%s',
                FLAGS.container, str(pod.status))
            exit(1)
        # Plain attribute access; the previous `getattr(..., 'terminated')`
        # had no default and was equivalent but misleading.
        if status.state.terminated:
            logging.warning('Container `%s` terminated with status:\n%s',
                            FLAGS.container, str(status))
            break

        time.sleep(FLAGS.interval)
Example #3
0
def update_tpu_runtime(tpu_name, version):
    """Set the runtime version of the TPU `tpu_name` to `version.tpu`.

    Installs the `cloud-tpu-client` package on demand if it is not already
    importable.

    Args:
        tpu_name: name of the Cloud TPU to reconfigure.
        version: object with a `.tpu` attribute naming the target runtime
            version (e.g. a parsed version spec).

    Raises:
        subprocess.CalledProcessError: if the on-demand pip install fails.
    """
    print(f'Updating TPU runtime to {version.tpu} ...')

    try:
        import cloud_tpu_client
    except ImportError:
        import sys
        # Install into the interpreter actually running this script: a bare
        # `pip` on PATH may belong to a different Python installation, which
        # would leave the subsequent import still failing. check_call (not
        # call) surfaces the install failure directly instead of masking it
        # behind a confusing second ImportError.
        subprocess.check_call(
            [sys.executable, '-m', 'pip', 'install', 'cloud-tpu-client'])
        import cloud_tpu_client

    client = cloud_tpu_client.Client(tpu_name)
    client.configure_tpu_version(version.tpu)
    print('Done updating TPU runtime')
Example #4
0
    def _set_tpuvm_mode(self):
        """Detect TPUVM mode, directly or via a remote coordinator.

        Sets `self._tpuvm_mode` and
        `self._tpuvm_mode_with_remote_coordinator` based on instance
        metadata and, failing that, the TPU API version.
        """
        self._tpuvm_mode = False
        self._tpuvm_mode_with_remote_coordinator = False

        accel_type = ClusterResolver.get_instance_metadata(
            'instance/attributes/accelerator-type')
        # Only a VM with a TPU attached carries the accelerator-type
        # metadata (e.g. "v3-8"). NOTE(review): assumes the metadata lookup
        # returns a string (possibly empty) rather than None — confirm.
        if re.match(r'v[0-9]+-[0-9]+', accel_type):
            self._tpuvm_mode = True
            return

        # No local accelerator metadata: probe the first TPU's API version.
        api_version = cloud_tpu_client.Client(
            tpu=self._tpus[0])._get_tpu_property('apiVersion')
        if api_version != 'V2_ALPHA1':
            return

        # Only the TPUVM API reports V2_ALPHA1. Since this VM itself lacks
        # the accelerator metadata but the configured TPU is a TPUVM,
        # assume this host acts as a remote coordinator.
        self._tpuvm_mode = True
        self._tpuvm_mode_with_remote_coordinator = True
Example #5
0
        def add_tpu_worker(tpu_name):
            """Append one worker per network endpoint of `tpu_name`.

            Builds ClientWorker records when `as_client_worker` is set,
            otherwise ServiceWorker records. Raises RuntimeError unless the
            TPU is both READY and HEALTHY. Mutates the enclosing `workers`
            list.
            """
            tpu_client = cloud_tpu_client.Client(tpu=tpu_name)
            tpu_name = tpu_client.name()
            # Fail fast: all TPUs must be fully provisioned before use.
            if tpu_client.state() != 'READY':
                raise RuntimeError(
                    'TPU {tpu_name} is not READY yet. '
                    'Re-run when all TPUs are READY'.format(tpu_name=tpu_name))
            if tpu_client.health() != 'HEALTHY':
                raise RuntimeError(
                    'TPU {tpu_name} is not HEALTHY yet. '
                    'Re-run when all TPUs are HEALTHY'.format(tpu_name=tpu_name))

            runtime_version = tpu_client.runtime_version()
            machine_type = tpu_client.accelerator_type()
            zone = ClusterResolver._parse_resource_url(tpu_client._full_name(),
                                                       'locations')
            endpoints = tpu_client.network_endpoints()

            if as_client_worker:
                # Client workers additionally need the hostname behind each
                # internal IP.
                hostname_by_ip = ClusterResolver._get_internal_ip_to_hostname_mapping(
                    tpu_name, zone, len(endpoints))
                workers.extend(
                    ClientWorker(internal_ip=endpoint['ipAddress'],
                                 machine_type=machine_type,
                                 zone=zone,
                                 hostname=hostname_by_ip[endpoint['ipAddress']])
                    for endpoint in endpoints)
            else:
                workers.extend(
                    ServiceWorker(internal_ip=endpoint['ipAddress'],
                                  port=endpoint['port'],
                                  machine_type=machine_type,
                                  zone=zone,
                                  runtime_version=runtime_version,
                                  tpu=tpu_name)
                    for endpoint in endpoints)
Example #6
0
 def wait_for_healthy_service_worker(tpu_name):
     """Block until the TPU named `tpu_name` reports as healthy."""
     cloud_tpu_client.Client(tpu=tpu_name).wait_for_healthy()
Example #7
0
 def _tpu_with_health(tpu_name):
     """Return `tpu_name` if its health equals the enclosing `health`.

     Returns None (implicitly) when the health does not match.
     """
     if cloud_tpu_client.Client(tpu_name).health() == health:
         return tpu_name