Ejemplo n.º 1
0
def _configure_tpu_version(tpu_name, version_label, new_version_id):
    """Connects to a TPU, optionally reconfigures its version, and reports it.

    Args:
      tpu_name: Arbitrary, user-chosen unique string naming this TPU.
      version_label: Forwarded to `_get_version_info`; also used as the
        'branch' value of the placeholder result when no endpoint is found.
      new_version_id: Version to reconfigure the TPU to. When falsy, no
        reconfiguration is requested and the TPU is used as-is.

    Returns:
      Whatever `_get_version_info` returns for the first network endpoint's
      `/requestversion` URL on port 8475. If the TPU reports no network
      endpoints, returns a placeholder dict with empty 'url', 'hash' and
      'piper_id' and 'branch' set to `version_label`.
    """
    logging.info('Trying to connect to tpu %s', tpu_name)
    tpu_handle = client.Client(tpu=tpu_name)
    tpu_handle.wait_for_healthy()

    if not new_version_id:
        logging.info('Using the default tpu version id.')
    else:
        logging.info('Trying to reset tpu version to %s', new_version_id)
        tpu_handle.configure_tpu_version(version=new_version_id)
        # Block until the TPU reports healthy again before querying it.
        tpu_handle.wait_for_healthy()
        logging.info('TPU healthy after version reset.')

    endpoints = tpu_handle.network_endpoints()
    if not endpoints:
        logging.error('No tpu endpoint info')
        return {
            'url': '',
            'hash': '',
            'branch': version_label,
            'piper_id': '',
        }

    # Only the first endpoint is asked for its version.
    version_url = 'http://{ip}:8475/requestversion'.format(
        ip=endpoints[0]['ipAddress'])
    return _get_version_info(version_url, version_label)
Ejemplo n.º 2
0
def get_tpu_version(tpu_address):
    """Reports the software version currently running on a TPU.

    Args:
      tpu_address: TPU identifier passed to `client.Client` as `tpu`.

    Returns:
      Whatever `_get_version_info` returns for the first network endpoint's
      `/requestversion` URL on port 8475. If the TPU reports no network
      endpoints, returns a placeholder dict whose 'url', 'hash', 'branch' and
      'piper_id' values are all empty strings.
    """
    logging.info('Trying to connect to tpu %s', tpu_address)
    tpu_handle = client.Client(tpu=tpu_address)
    tpu_handle.wait_for_healthy()

    endpoints = tpu_handle.network_endpoints()
    if not endpoints:
        logging.error('No tpu endpoint info')
        return {
            'url': '',
            'hash': '',
            'branch': '',
            'piper_id': '',
        }

    # Only the first endpoint is asked for its version.
    version_url = 'http://{ip}:8475/requestversion'.format(
        ip=endpoints[0]['ipAddress'])
    return _get_version_info(version_url)
Ejemplo n.º 3
0
    def __init__(self,
                 tpu=None,
                 zone=None,
                 project=None,
                 job_name='worker',
                 coordinator_name=None,
                 coordinator_address=None,
                 credentials='default',
                 service=None,
                 discovery_url=None):
        """Builds a cluster resolver backed by a Cloud TPU API client.

        All TPU-locating parameters are handed to `client.Client`, which is
        responsible for querying the Cloud TPU APIs for worker addresses.

        Args:
          tpu: TPU name or TPU worker gRPC address, forwarded to the client.
            Per the original contract, when unset the client tries to resolve
            the TPU address automatically on Cloud TPUs.
          zone: Zone of the TPU, forwarded to the client. Per the original
            contract, when omitted or empty the GCE VM's zone is discovered
            from the GCE metadata service.
          project: GCP project containing the TPU, forwarded to the client.
            Per the original contract, when omitted or empty the GCE VM's
            project is discovered from the GCE metadata service.
          job_name: TensorFlow job name the TPUs belong to. It is stored as
            `self.task_type`.
          coordinator_name: Name to use for the coordinator, stored on the
            instance. If it is falsy, no local server is started.
          coordinator_address: Coordinator address (typically ip:port). If it
            is falsy and `coordinator_name` is truthy, `_start_local_server()`
            is called instead of storing this value.
          credentials: GCE credentials, forwarded to the client. Defaults to
            the string 'default'. NOTE(review): older docs say None selects
            oauth2client default credentials; confirm how the client treats
            'default' vs None.
          service: googleapiclient discovery service object, forwarded to the
            client. Per the original contract, when provided, `credentials`
            is ignored.
          discovery_url: Discovery-service URL template with {api} and
            {apiVersion} placeholders, forwarded to the client. Per the
            original contract, the TPU_API_DISCOVERY_URL environment variable
            overrides it.

        Raises:
          Propagates errors from `client.Client` and `_start_local_server`.
          The original contract lists ImportError (googleapiclient missing),
          ValueError (no TPUs specified), and RuntimeError (empty TPU name
          while running in a Google Cloud environment).
        """
        client_kwargs = dict(
            tpu=tpu,
            zone=zone,
            project=project,
            credentials=credentials,
            service=service,
            discovery_url=discovery_url,
        )
        self._cloud_tpu_client = client.Client(**client_kwargs)
        self._tpu = self._cloud_tpu_client.name()

        # The resolver starts out as task 0 of `job_name` ('worker' unless
        # overridden), i.e. the first task of that job.
        self.task_type = job_name
        self.task_id = 0

        self._coordinator_name = coordinator_name
        if not coordinator_name or coordinator_address:
            self._coordinator_address = coordinator_address
        else:
            # A coordinator is wanted but no address was supplied.
            # NOTE(review): `_coordinator_address` is not assigned on this
            # path here; presumably `_start_local_server` sets it. Confirm.
            self._start_local_server()