Example #1
    def test_connect(self):
        self.assertCountEqual(EXPECTED_DEVICES_PRE_CONNECT,
                              context.list_devices())

        resolver = tpu_cluster_resolver.TPUClusterResolver(
            tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
        remote.connect_to_cluster(resolver)

        self.assertCountEqual(EXPECTED_DEVICES_AFTER_CONNECT,
                              context.list_devices())

        tpu_strategy_util.initialize_tpu_system(resolver)
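
The same connect-and-initialize flow is available through the public TF2 API. A minimal sketch, assuming a TPU reachable by name (an empty string targets a local TPU worker / TPU VM); these are the public equivalents of the internal modules used above:

import tensorflow as tf

# Resolve the TPU cluster; "" targets a local TPU worker / TPU VM.
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)

# Initialize the TPU system; afterwards the remote TPU devices are visible.
tf.tpu.experimental.initialize_tpu_system(resolver)
print(tf.config.list_logical_devices("TPU"))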
Example #2
def get_first_tpu_host_device(cluster_resolver):
    """Get the device spec for the first TPU host."""
    if context.executing_eagerly():
        tpu_devices = sorted(
            [x for x in context.list_devices() if "device:TPU:" in x])
        if not tpu_devices:
            raise RuntimeError("Could not find any TPU devices")
        spec = tf_device.DeviceSpec.from_string(tpu_devices[0])
        task_id = spec.task
    else:
        # Session master needs to be configured and the coordinator is not part
        # of the cluster.
        task_id = 0
    if cluster_resolver.get_master() in ("", "local"):
        return "/replica:0/task:0/device:CPU:0"
    job_name = cluster_resolver.get_job_name() or "tpu_worker"
    return "/job:%s/task:%d/device:CPU:0" % (job_name, task_id)
Example #3
def get_accelerator_devices(master, config_proto):
  """Returns accelerator devices given a master and a configuration."""
  if context.executing_eagerly():
    device_names = context.list_devices()  # list_devices returns list(string)
    devices = []
    for name in device_names:
      device_type = 'GPU'  # default device type is GPU
      device_match = DEVICE_TYPE_REGEX.match(name)
      if device_match:
        device_type = device_match.group(1)
      if device_type == 'CPU' or device_type == 'XLA_CPU':  # Filter CPUs
        continue
      devices.append(session._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
    return devices
  else:
    with ops.Graph().as_default():
      with session.Session(master, config=config_proto) as s:
        devices = s.list_devices()
    return devices
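
DEVICE_TYPE_REGEX is defined elsewhere in the module; a pattern along the following lines (an assumption for illustration, not the verbatim source definition) extracts the device type from a fully qualified device name:

import re

# Hypothetical pattern: captures the type ("GPU", "TPU", "XLA_CPU", ...) from
# the trailing "device:<TYPE>:<index>" component of a device name.
DEVICE_TYPE_REGEX = re.compile(r'.*device:([^:]+):[0-9]+$')

match = DEVICE_TYPE_REGEX.match('/job:worker/replica:0/task:0/device:GPU:0')
assert match and match.group(1) == 'GPU'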
Example #4
def _query_tpu_system_metadata(master_address,
                               cluster_def=None,
                               query_topology=False):
    """Automatically detects the TPU system metadata in the system."""
    tpu_core_count = 0
    devices = []
    device_dict = collections.defaultdict(list)

    if context.executing_eagerly():
        device_names = context.list_devices()
        devices = []

        # We want the output type to match in both eager and session mode
        for name in device_names:
            device_match = _DEVICE_TYPE_REGEX.match(name)
            device_type = 'CPU'
            if device_match:
                device_type = device_match.group(1)
            devices.append(
                session_lib._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
    else:
        # TODO(b/120564445): Replace with standard library for retries.
        retry_count = 1
        while True:
            logging.info(
                'Querying Tensorflow master (%s) for TPU system metadata.',
                master_address)
            try:
                with ops.Graph().as_default():
                    with session_lib.Session(
                            master_address,
                            config=get_session_config_with_timeout(
                                _PINGING_MASTER_TIMEOUT_IN_MS,
                                cluster_def)) as sess:
                        devices = sess.list_devices()
                        break
            except errors.DeadlineExceededError:
                msg = (
                    'Failed to connect to the Tensorflow master. The TPU worker may '
                    'not be ready (still scheduling) or the Tensorflow master '
                    'address is incorrect: got (%s).' % (master_address))

                # TODO(xiejw): For local or grpc master we might not need retry logic
                # here.
                if retry_count <= _RETRY_TIMES:
                    logging.warning('%s', msg)
                    logging.warning('Retrying (%d/%d).', retry_count,
                                    _RETRY_TIMES)
                    retry_count += 1
                else:
                    raise ValueError(msg)

    for device in devices:
        match = _TPU_DEVICE_REG.match(device.name)
        if match:
            host_id = match.group(1)
            core_id = match.group(2)
            device_dict[host_id].append(core_id)
            tpu_core_count += 1

    num_of_cores_per_host = 0
    if tpu_core_count:
        num_cores_per_host_set = set(
            [len(core_ids) for core_ids in device_dict.values()])
        if len(num_cores_per_host_set) != 1:
            raise RuntimeError(
                'TPU cores on each host are not the same. This should not '
                'happen. devices: {}'.format(devices))
        num_of_cores_per_host = num_cores_per_host_set.pop()

    topology = None
    if query_topology:
        if not tpu_core_count:
            raise RuntimeError(
                'Cannot find any TPU cores in the system (master address {}). '
                'This usually means the master address is incorrect or the '
                'TPU worker has some problems. Available devices: {}'.format(
                    master_address, devices))

        topology = _obtain_topology(master_address, cluster_def)

    # We sort the metadata devices so that downstream users get a sorted list
    # for creating mirrored variables correctly.
    def _sort_key(device):
        spec = tf_device.DeviceSpec.from_string(device.name)
        return (spec.job, spec.replica, spec.task, spec.device_type,
                spec.device_index)

    devices = tuple(sorted(devices, key=_sort_key))

    metadata = _TPUSystemMetadata(num_cores=tpu_core_count,
                                  num_hosts=len(device_dict),
                                  num_of_cores_per_host=num_of_cores_per_host,
                                  topology=topology,
                                  devices=devices)

    if tpu_core_count:
        logging.info('Found TPU system:')
        logging.info('*** Num TPU Cores: %d', metadata.num_cores)
        logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
        logging.info('*** Num TPU Cores Per Worker: %d',
                     metadata.num_of_cores_per_host)
        for device in metadata.devices:
            logging.info('*** Available Device: %s', device)
    else:
        logging.info('Failed to find TPU: %s', metadata)
    return metadata
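
A hypothetical call, assuming a cluster_resolver as in the earlier examples (local single-host setups may have no cluster spec, in which case cluster_def can stay None):

# Usage sketch: query TPU system metadata for the resolver's master.
cluster_spec = cluster_resolver.cluster_spec()
metadata = _query_tpu_system_metadata(
    cluster_resolver.master(),
    cluster_def=cluster_spec.as_cluster_def() if cluster_spec else None,
    query_topology=True)
print(metadata.num_cores, metadata.num_hosts, metadata.num_of_cores_per_host)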
Example #5
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.")

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function is written the way it is for the following non-intuitive
    # reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device. As
    # in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM should
    # work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
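
A typical call sequence, shown as a sketch (it assumes the remote module from Example 1 and a reachable TPU), initializes the system once before building any TPU computation:

# Usage sketch: connect, initialize, then inspect the returned topology.
resolver = TPUClusterResolver(tpu="")
remote.connect_to_cluster(resolver)        # only needed for a remote TPU
tpu_topology = initialize_tpu_system(resolver)
print("TPU mesh shape:", tpu_topology.mesh_shape)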
Example #6
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.
  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.")

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function is written the way it is for the following non-intuitive
    # reasons.
    # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    # The easiest way to trigger a rewrite is to run the function with
    # TPUPartitionedCallOp.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # We can't call _tpu_init_fn normally (because it contains just a dummy op,
    # see above) but need to define it to get it added to eager context
    # and get its assigned name.
    # pylint: disable=protected-access
    graph_func = _tpu_init_fn._get_concrete_function_internal()
    func_name = compat.as_str(graph_func._inference_function.name)
    # pylint: enable=protected-access

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])

    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    with ops.device(device_util.get_host_for_device(tpu_devices[0])):
      output = tpu_functional_ops.TPUPartitionedCall(
          args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
    serialized_topology = output[0].numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology