Ejemplo n.º 1
0
    def _get_internal_ip_to_hostname_mapping(tpu_name, zone, num_vm):
        """Gets TPU VM internal IP to hostname mapping.

        Currently TPU CLH does not expose any TPU host machine name. SSH to each
        worker and get that instead.

        Args:
          tpu_name: Name of the TPU whose VM workers are queried.
          zone: Zone passed to `gcloud ... --zone`.
          num_vm: Number of VM workers; workers 0..num_vm-1 are queried.

        Returns:
          A map of TPU VM internal IP to TPU VM hostname.
        """
        ip_to_host_name = {}
        if num_vm <= 0:
            # Nothing to query; skip xu.parallel_work entirely.
            # NOTE(review): assumes parallel_work may reject zero workers.
            return ip_to_host_name

        def add_tpuvm_ip_to_hostname_mapping(worker_index):
            # The remote command prints the hostname on the first line and the
            # internal IP (`hostname -i`) on the second.
            cmd = [
                'gcloud', 'alpha', 'compute', 'tpus', 'tpu-vm', 'ssh',
                '--internal-ip', tpu_name, '--zone', zone, '--worker',
                str(worker_index), '--command', 'hostname; hostname -i'
            ]
            # Using Popen as a context manager together with communicate()
            # waits for the ssh process to exit and closes its stdout pipe.
            # This avoids leaving zombie processes and open descriptors behind.
            with subprocess.Popen(cmd, stdout=subprocess.PIPE) as proc:
                stdout, _ = proc.communicate()
            # Same parsing as reading two lines and stripping the trailing
            # newline. Missing lines become '' (as readline() would return).
            lines = stdout.decode('utf-8').split('\n')
            hostname = lines[0]
            ip = lines[1] if len(lines) > 1 else ''
            ip_to_host_name[ip] = hostname

        xu.parallel_work(num_vm, add_tpuvm_ip_to_hostname_mapping,
                         list(range(num_vm)))
        return ip_to_host_name
Ejemplo n.º 2
0
    def _start_run(self, script_map):
        """Runs every remote script in parallel while a background heartbeat
        thread periodically checks client mesh health.

        Args:
          script_map: Mapping of client worker -> script paths dict; the dict's
            'remote_path' entry is the script executed on that worker.

        Raises:
          RuntimeError: (from a worker task) if a remote script exits non-zero.
        """

        def _run_script(script_paths, client_worker):
            remote_path = script_paths['remote_path']
            # In TPU VM mode with restart_server set, kill any lingering server
            # process matching the pattern before launching the script.
            if self.restart_server and self.tpuvm_mode:
                self._build_and_run_ssh(
                    'pkill -f "^python -m {} [0-9]+$"'.format(
                        self.XRT_RUN_SERVER_PROCESS),
                    client_worker,
                    log=False)
            status = self._build_and_run_ssh([remote_path], client_worker)
            if status == 0:
                return
            raise RuntimeError(
                'Remote command exitted with code: {}'.format(status))

        def _regular_health_check():
            # Timeouts are read once, from the environment with defaults.
            uneven_timeout = xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int,
                                          900)
            even_timeout = xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, 1800)
            while True:
                self._check_client_mesh_health(uneven_timeout, even_timeout)
                time.sleep(self.HEARTBEAT_CHECK_PERIOD)

        # Daemon thread: it loops forever and must not keep the process alive.
        monitor = threading.Thread(target=_regular_health_check, daemon=True)
        monitor.start()
        xu.parallel_work(len(script_map), _run_script, script_map.values(),
                         script_map.keys())
Ejemplo n.º 3
0
    def wait_for_healthy_service(self):
        """Calls Client.wait_for_healthy() on every TPU currently reported as
        'UNHEALTHY_MAINTENANCE', running the waits via xu.parallel_work.

        Does nothing when no TPU is in that state.
        """
        maintenance_tpus = self.list_tpus_with_health('UNHEALTHY_MAINTENANCE')
        if not maintenance_tpus:
            return

        def wait_for_healthy_service_worker(tpu_name):
            cloud_tpu_client.Client(tpu=tpu_name).wait_for_healthy()

        xu.parallel_work(len(maintenance_tpus), wait_for_healthy_service_worker,
                         maintenance_tpus)
Ejemplo n.º 4
0
    def get_tpu_workers(self, as_client_worker=False):
        """Gets TPU VM cluster info.

        Calls the TPU CLH to get TPU node data and returns list of TPU worker
        VMs internal IP addresses. If zone and project are not specified at
        ClusterResolver init time, we infer these bits from GCE metadata.

        Args:
          as_client_worker: If True, build ClientWorker objects, resolving each
            endpoint's hostname over SSH. Otherwise build ServiceWorker objects.

        Returns:
          A list of ServiceWorker or a list of ClientWorker (empty when no TPUs
          are configured). Workers are appended from parallel tasks, so the
          order across TPUs is presumably not deterministic.

        Raises:
          RuntimeError: If a TPU is not READY or not HEALTHY, or (client mode)
            if an endpoint's internal IP could not be resolved to a hostname.
            Errors from cloud_tpu_client.Client (e.g. presumably for a TPU that
            does not exist) propagate as raised there.
        """
        workers = []
        if not self._tpus:
            # Nothing to resolve; skip xu.parallel_work with zero workers.
            # NOTE(review): assumes parallel_work may reject zero workers.
            return workers

        def add_tpu_worker(tpu_name):
            ctc = cloud_tpu_client.Client(tpu=tpu_name)
            tpu_name = ctc.name()
            if ctc.state() != 'READY':
                raise RuntimeError(('TPU {tpu_name} is not READY yet. '
                                    'Re-run when all TPUs are READY').format(
                                        tpu_name=tpu_name))
            if ctc.health() != 'HEALTHY':
                raise RuntimeError(('TPU {tpu_name} is not HEALTHY yet. '
                                    'Re-run when all TPUs are HEALTHY').format(
                                        tpu_name=tpu_name))

            runtime_version = ctc.runtime_version()
            machine_type = ctc.accelerator_type()
            zone = ClusterResolver._parse_resource_url(ctc._full_name(),
                                                       'locations')
            network_endpoints = ctc.network_endpoints()

            if as_client_worker:
                ip_to_host_name = ClusterResolver._get_internal_ip_to_hostname_mapping(
                    tpu_name, zone, len(network_endpoints))

            for endpoint in network_endpoints:
                if as_client_worker:
                    internal_ip = endpoint['ipAddress']
                    hostname = ip_to_host_name.get(internal_ip)
                    if hostname is None:
                        # The SSH-based lookup did not report this IP (e.g. the
                        # ssh command failed or printed unexpected output).
                        raise RuntimeError(
                            ('Could not resolve hostname for TPU {tpu_name} '
                             'worker with internal IP {ip}').format(
                                 tpu_name=tpu_name, ip=internal_ip))
                    worker = ClientWorker(internal_ip=internal_ip,
                                          machine_type=machine_type,
                                          zone=zone,
                                          hostname=hostname)
                else:
                    worker = ServiceWorker(internal_ip=endpoint['ipAddress'],
                                           port=endpoint['port'],
                                           machine_type=machine_type,
                                           zone=zone,
                                           runtime_version=runtime_version,
                                           tpu=tpu_name)
                workers.append(worker)

        xu.parallel_work(len(self._tpus), add_tpu_worker, self._tpus)

        return workers
Ejemplo n.º 5
0
    def list_tpus_with_health(self, health):
        """Returns the TPUs backing this cluster's service workers whose current
        health equals `health`.

        Args:
          health: Value compared against cloud_tpu_client.Client.health(),
            e.g. 'UNHEALTHY_MAINTENANCE'.

        Returns:
          A list of matching TPU names. It is empty when none match or when
          there are no service workers. TPU names are deduplicated through a
          set, so order is not guaranteed.
        """

        def _tpu_with_health(tpu_name):
            ctc = cloud_tpu_client.Client(tpu_name)
            if ctc.health() == health:
                return tpu_name
            return None

        tpus = {service_worker._tpu for service_worker in self._service_workers}
        if not tpus:
            # Nothing to query; skip xu.parallel_work with zero workers.
            # NOTE(review): assumes parallel_work may reject zero workers.
            return []
        results = xu.parallel_work(len(tpus), _tpu_with_health, tpus)
        return [res for res in results if res]
Ejemplo n.º 6
0
    def wait_for_healthy_client(self,
                                dist_executor,
                                timeout=1200,
                                interval=10):
        """Waits until every client worker answers an SSH heartbeat.

        Workers are checked via xu.parallel_work. A single check runs
        `dist_executor._build_and_run_ssh` in a child process. It succeeds
        when that process exits with code 0 within `interval` seconds.

        Args:
          dist_executor: Object providing `_build_and_run_ssh(cmd, worker)`,
            used to run the heartbeat command on each client worker.
          timeout: Seconds to keep retrying a worker before giving up.
          interval: Seconds allowed per check attempt. This is also the minimum
            spacing between consecutive attempts on the same worker.

        Raises:
          RuntimeError: If a client worker is not healthy before `timeout`.
        """

        def wait_for_healthy_client_worker(client_worker):
            heartbeat_check = [
                'echo', 'client_worker', '$(hostname)', 'is', 'healthy'
            ]
            check_timeout = time.time() + timeout

            def _healthy_client_worker():
                proc = multiprocessing.Process(
                    target=dist_executor._build_and_run_ssh,
                    args=(
                        heartbeat_check,
                        client_worker,
                    ))
                proc.daemon = True
                proc.start()
                proc.join(interval)

                if proc.is_alive():
                    proc.terminate()
                    # terminate() only signals the child. Join it so it is
                    # reaped rather than left as a zombie.
                    # NOTE(review): per the multiprocessing docs, descendants
                    # of the child (e.g. an ssh process) are not terminated.
                    proc.join()
                    return False

                return proc.exitcode == 0

            while True:
                attempt_start = time.time()
                if _healthy_client_worker():
                    break
                logging.warning(
                    'Waiting for client_worker "{}" to become healthy'.format(
                        client_worker))
                if time.time() + interval > check_timeout:
                    raise RuntimeError(
                        'Timed out waiting for client_worker {} to become healthy'
                        .format(client_worker))
                # A check can fail quickly (non-zero exit right away). Wait out
                # the rest of the interval so retries are paced instead of
                # spinning.
                remaining = interval - (time.time() - attempt_start)
                if remaining > 0:
                    time.sleep(remaining)

            logging.warning(
                'client_worker "{}" is healthy.'.format(client_worker))

        xu.parallel_work(len(self._client_workers),
                         wait_for_healthy_client_worker, self._client_workers)
Ejemplo n.º 7
0
    def _start_run(self, script_map):
        """Launches all remote scripts in parallel and monitors client mesh
        health from a background daemon thread.

        Args:
          script_map: Mapping of client worker -> script paths dict; each dict's
            'remote_path' entry is run on its worker over SSH.

        Raises:
          RuntimeError: (from a worker task) if a remote script exits non-zero.
        """

        def _run_script(script_paths, client_worker):
            status = self._build_and_run_ssh([script_paths['remote_path']],
                                             client_worker)
            if status == 0:
                return
            raise RuntimeError(
                'Remote command exitted with code: {}'.format(status))

        def _regular_health_check():
            uneven_timeout = xu.getenv_as('XLA_UNEVEN_HEARTBEAT_TIMEOUT', int,
                                          900)
            even_timeout = xu.getenv_as('XLA_EVEN_HEARTBEAT_TIMEOUT', int, 1800)
            # Loops for the life of the process; the thread is a daemon.
            while True:
                self._check_client_mesh_health(uneven_timeout, even_timeout)
                time.sleep(self.HEARTBEAT_CHECK_PERIOD)

        monitor = threading.Thread(target=_regular_health_check, daemon=True)
        monitor.start()
        xu.parallel_work(len(script_map), _run_script, script_map.values(),
                         script_map.keys())