def test_connect(self):
  self.assertCountEqual(EXPECTED_DEVICES_PRE_CONNECT, context.list_devices())

  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu=FLAGS.tpu, zone=FLAGS.zone, project=FLAGS.project)
  remote.connect_to_cluster(resolver)

  self.assertCountEqual(EXPECTED_DEVICES_AFTER_CONNECT, context.list_devices())

  tpu_strategy_util.initialize_tpu_system(resolver)
def get_first_tpu_host_device(cluster_resolver):
  """Get the device spec for the first TPU host."""
  if context.executing_eagerly():
    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])
    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")
    spec = tf_device.DeviceSpec.from_string(tpu_devices[0])
    task_id = spec.task
  else:
    # Session master needs to be configured and the coordinator is not part
    # of the cluster.
    task_id = 0
  if cluster_resolver.get_master() in ("", "local"):
    return "/replica:0/task:0/device:CPU:0"
  job_name = cluster_resolver.get_job_name() or "tpu_worker"
  return "/job:%s/task:%d/device:CPU:0" % (job_name, task_id)
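
# Illustrative usage sketch (not part of the original module): pin host-side
# ops to the CPU of the first TPU host. The resolver arguments ("my-tpu",
# "us-central1-b", "my-project") are placeholders, and `ops` is assumed to be
# tensorflow.python.framework.ops as in the functions below.
def _example_first_host_placement():
  resolver = tpu_cluster_resolver.TPUClusterResolver(
      tpu="my-tpu", zone="us-central1-b", project="my-project")
  host_device = get_first_tpu_host_device(resolver)
  with ops.device(host_device):
    # Build host-side (CPU) ops here; they are placed on the first TPU
    # worker's CPU rather than on the coordinator.
    pass
  return host_device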
def get_accelerator_devices(master, config_proto):
  """Returns accelerator devices given a master and a configuration."""
  if context.executing_eagerly():
    device_names = context.list_devices()  # list_devices returns list(string)
    devices = []
    for name in device_names:
      device_type = 'GPU'  # Default device type is GPU.
      device_match = DEVICE_TYPE_REGEX.match(name)
      if device_match:
        device_type = device_match.group(1)
      if device_type == 'CPU' or device_type == 'XLA_CPU':  # Filter CPUs.
        continue
      devices.append(session._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
    return devices
  else:
    with ops.Graph().as_default():
      with session.Session(master, config=config_proto) as s:
        devices = s.list_devices()
    return devices
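
# Illustrative sketch (added, not in the original module): list the distinct
# accelerator types visible from a master. The master address below is a
# placeholder, and passing config_proto=None simply uses the default session
# configuration.
def _example_list_accelerator_types(master="grpc://10.0.0.2:8470"):
  accelerators = get_accelerator_devices(master, config_proto=None)
  return sorted({d.device_type for d in accelerators})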
def _query_tpu_system_metadata(master_address, cluster_def=None,
                               query_topology=False):
  """Automatically detects the TPU system metadata in the system."""
  tpu_core_count = 0
  devices = []
  device_dict = collections.defaultdict(list)

  if context.executing_eagerly():
    device_names = context.list_devices()
    devices = []

    # We want the output type to match in both eager and session mode.
    for name in device_names:
      device_match = _DEVICE_TYPE_REGEX.match(name)
      device_type = 'CPU'
      if device_match:
        device_type = device_match.group(1)
      devices.append(session_lib._DeviceAttributes(name, device_type, 0, 0))  # pylint: disable=protected-access
  else:
    # TODO(b/120564445): Replace with standard library for retries.
    retry_count = 1
    while True:
      logging.info('Querying Tensorflow master (%s) for TPU system metadata.',
                   master_address)
      try:
        with ops.Graph().as_default():
          with session_lib.Session(
              master_address,
              config=get_session_config_with_timeout(
                  _PINGING_MASTER_TIMEOUT_IN_MS, cluster_def)) as sess:
            devices = sess.list_devices()
            break
      except errors.DeadlineExceededError:
        msg = ('Failed to connect to the Tensorflow master. The TPU worker '
               'may not be ready (still scheduling) or the Tensorflow master '
               'address is incorrect: got (%s).' % (master_address))

        # TODO(xiejw): For local or grpc master we might not need retry logic
        # here.
        if retry_count <= _RETRY_TIMES:
          logging.warning('%s', msg)
          logging.warning('Retrying (%d/%d).', retry_count, _RETRY_TIMES)
          retry_count += 1
        else:
          raise ValueError(msg)

  for device in devices:
    match = _TPU_DEVICE_REG.match(device.name)
    if match:
      host_id = match.group(1)
      core_id = match.group(2)
      device_dict[host_id].append(core_id)
      tpu_core_count += 1

  num_of_cores_per_host = 0
  if tpu_core_count:
    num_cores_per_host_set = set(
        [len(core_ids) for core_ids in device_dict.values()])
    if len(num_cores_per_host_set) != 1:
      raise RuntimeError(
          'The number of TPU cores is not the same on every host. This '
          'should not happen. devices: {}'.format(devices))
    num_of_cores_per_host = num_cores_per_host_set.pop()

  topology = None
  if query_topology:
    if not tpu_core_count:
      raise RuntimeError(
          'Cannot find any TPU cores in the system (master address {}). '
          'This usually means the master address is incorrect or the '
          'TPU worker has some problems. Available devices: {}'.format(
              master_address, devices))
    topology = _obtain_topology(master_address, cluster_def)

  # We sort the metadata devices so that downstream users get a sorted list
  # for creating mirrored variables correctly.
  def _sort_key(device):
    spec = tf_device.DeviceSpec.from_string(device.name)
    return (spec.job, spec.replica, spec.task, spec.device_type,
            spec.device_index)
  devices = tuple(sorted(devices, key=_sort_key))

  metadata = _TPUSystemMetadata(
      num_cores=tpu_core_count,
      num_hosts=len(device_dict),
      num_of_cores_per_host=num_of_cores_per_host,
      topology=topology,
      devices=devices)

  if tpu_core_count:
    logging.info('Found TPU system:')
    logging.info('*** Num TPU Cores: %d', metadata.num_cores)
    logging.info('*** Num TPU Workers: %d', metadata.num_hosts)
    logging.info('*** Num TPU Cores Per Worker: %d',
                 metadata.num_of_cores_per_host)
    for device in metadata.devices:
      logging.info('*** Available Device: %s', device)
  else:
    logging.info('Failed to find TPU: %s', metadata)
  return metadata
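
# Illustrative sketch (added, not in the original module): query and log the
# TPU system metadata for a master address. _query_tpu_system_metadata is a
# private helper and is called here purely for illustration; the address is a
# placeholder, and cluster_def would normally come from a cluster resolver's
# cluster_spec().
def _example_log_tpu_metadata(master_address="grpc://10.0.0.2:8470"):
  metadata = _query_tpu_system_metadata(
      master_address, cluster_def=None, query_topology=False)
  logging.info('TPU cores: %d across %d hosts', metadata.num_cores,
               metadata.num_hosts)
  return metadata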
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function is written the way it is for the following non-intuitive
    # reasons: tpu.initialize_system creates a dummy op whose sole purpose is
    # to trigger DistributedTPURewritePass. This pass actually adds real ops
    # that initialize the TPU system. Thus, we can't simply run
    # tpu.initialize_system eagerly. We need to wrap it in a defun and trigger
    # the rewrite passes on it.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])
    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    # Replace the remote TPU device with the remote TPU_SYSTEM system device.
    # As in the remote TPU device case, we will try to compile it instead of
    # running through optimization passes and TF Executor, but TPU_SYSTEM
    # should work.
    tpu_system_device = tpu_devices[0].replace("TPU", "TPU_SYSTEM")

    with ops.device(tpu_system_device):
      output = _tpu_init_fn()
    serialized_topology = output.numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology
  return tpu_topology
def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
      which provides information about the TPU cluster.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster.

  Raises:
    RuntimeError: If no TPU devices found for eager execution.
  """
  if cluster_resolver is None:
    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning("TPU system %s has already been initialized. "
                    "Reinitializing the TPU can cause previously created "
                    "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system.")

  if context.executing_eagerly():
    # This function is written the way it is for the following non-intuitive
    # reasons: tpu.initialize_system creates a dummy op whose sole purpose is
    # to trigger DistributedTPURewritePass. This pass actually adds real ops
    # that initialize the TPU system. Thus, we can't simply run
    # tpu.initialize_system eagerly. We need to wrap it in a defun and trigger
    # the rewrite passes on it. The easiest way to trigger a rewrite is to run
    # the function with TPUPartitionedCallOp.
    @function.defun
    def _tpu_init_fn():
      return tpu.initialize_system()

    # We can't call _tpu_init_fn normally (because it contains just a dummy
    # op, see above), but we need to define it to get it added to the eager
    # context and to obtain its assigned name.
    # pylint: disable=protected-access
    graph_func = _tpu_init_fn._get_concrete_function_internal()
    func_name = compat.as_str(graph_func._inference_function.name)
    # pylint: enable=protected-access

    tpu_devices = sorted(
        [x for x in context.list_devices() if "device:TPU:" in x])
    if not tpu_devices:
      raise RuntimeError("Could not find any TPU devices")

    with ops.device(device_util.get_host_for_device(tpu_devices[0])):
      output = tpu_functional_ops.TPUPartitionedCall(
          args=[], device_ordinal=0, Tout=[dtypes.string], f=func_name)
    serialized_topology = output[0].numpy()
  else:
    master = cluster_resolver.master()
    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        serialized_topology = sess.run(tpu.initialize_system())

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology
  return tpu_topology
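
# Illustrative sketch (added, not in the original module): a typical
# eager-mode startup sequence. The TPU name is a placeholder, and `remote`
# refers to the same module used in test_connect above
# (tensorflow.python.eager.remote).
def _example_initialize(tpu_name="my-tpu"):
  resolver = TPUClusterResolver(tpu=tpu_name)
  remote.connect_to_cluster(resolver)
  tpu_topology = initialize_tpu_system(resolver)
  logging.info("TPU topology has %d tasks", tpu_topology.num_tasks)
  return tpu_topology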