def initialize_tpu_system(cluster_resolver=None):
  """Initialize the TPU devices.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. If None, a default
        `TPUClusterResolver("")` is constructed.

  Returns:
    The tf.tpu.Topology object for the topology of the TPU cluster. If called
    inside tf.function, it returns the serialized topology object instead.

  Raises:
    NotFoundError: If no TPU devices found in eager mode.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the init
    # ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    logging.warning(
        "TPU system %s has already been initialized. "
        "Reinitializing the TPU can cause previously created "
        "variables on TPU to be lost.", tpu_name)

  logging.info("Initializing the TPU system: %s", tpu_name)

  # This function looks as it is for the following non-intuitive reasons.
  # tpu.initialize_system creates a dummy op whose sole purpose is to trigger
  # DistributedTPURewritePass. This pass actually adds real ops that
  # initialize the TPU system. Thus, we can't simply run tpu.initialize_system
  # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
  if tpu_name not in _LOCAL_MASTERS:
    # Explicitly place the tpu.initialize_system in the first worker to
    # avoid the output node match multiple devices error.
    job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

  if context.executing_eagerly():
    @function.defun
    def _tpu_init_fn():
      # In TF1, we usually close chips when compilation fails to clear the data
      # in infeed. In TF2, we don't need to do this because infeed is no longer
      # used, so user can recover from TPU compilation failures more smoothly.
      return tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)

    # The TPU_SYSTEM device must match the device used in tpu.initialize_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    try:
      with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
        output = _tpu_init_fn()
      # Block until the asynchronous initialization has actually finished so
      # that failures surface here rather than at a later, unrelated op.
      context.async_wait()
    except errors.InvalidArgumentError as e:
      # Chain the original error so its traceback is not lost when it is
      # re-raised as a NotFoundError.
      raise errors.NotFoundError(
          None, None,
          "TPUs not found in the cluster. Failed in initialization: "
          + str(e)) from e

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access

    serialized_topology = output.numpy()
  elif not ops.executing_eagerly_outside_functions():
    # Graph (TF1) mode: run the initialization op in a throwaway session
    # connected to the resolver's master.
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        # NOTE(review): unlike the other branches, `job` is not passed here;
        # placement is left to soft placement in the session.
        serialized_topology = sess.run(tpu.initialize_system())
  else:
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      serialized_topology = tpu.initialize_system(
          job=job, compilation_failure_closes_chips=False)
    # If initialize_tpu_system is called inside tf.function, we only return
    # the serialized topology object as the tf.tpu.Topology object has to be
    # constructed in eager mode. This early return also means the TPU is not
    # recorded in _INITIALIZED_TPU_SYSTEMS.
    return serialized_topology

  logging.info("Finished initializing TPU system.")
  tpu_topology = topology.Topology(serialized=serialized_topology)
  _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

  return tpu_topology
def shutdown_tpu_system(cluster_resolver=None):
  """Shuts down the TPU devices.

  This will clear all caches, even those that are maintained through sequential
  calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
  cache.

  Args:
    cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster. If None, a default
        `TPUClusterResolver("")` is constructed.

  Raises:
    RuntimeError: If run inside a tf.function. Errors raised while executing
        the shutdown op itself are not translated and propagate to the caller.
  """
  job = None
  if cluster_resolver is None:
    # If no cluster resolver is specified, and running eagerly, execute the init
    # ops in the current device scope.
    if context.executing_eagerly():
      curr_device = device.DeviceSpec.from_string(
          context.context().device_name)
      if curr_device.job is not None:
        job = "{}/replica:0/task:0".format(curr_device.job)

    cluster_resolver = TPUClusterResolver("")
  assert isinstance(cluster_resolver, TPUClusterResolver)

  tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
  if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
    # Use lazy logging arguments (consistent with initialize_tpu_system) so
    # the message is only formatted if the record is actually emitted.
    logging.warning(
        "You are shutting down a TPU system %s that has not been "
        "initialized.", tpu_name)

  logging.info("Shutting down the TPU system: %s", tpu_name)

  if context.executing_eagerly():
    # This function looks as it is for the following non-intuitive reasons.
    # tpu.shutdown_system creates a dummy op whose sole purpose is to trigger
    # DistributedTPURewritePass. This pass actually adds real ops that
    # shutdown the TPU system. Thus, we can't simply run tpu.shutdown_system
    # eagerly. We need to wrap it in defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
      # Explicitly place the tpu.shutdown_system in the first worker to
      # avoid the output node match multiple devices error.
      job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    @function.defun
    def _tpu_shutdown_fn():
      tpu.shutdown_system(job=job)

    # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
    # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
    # devices available.
    with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
      _tpu_shutdown_fn()

    # Clear out the eager context caches since the memory is invalid now.
    logging.info("Clearing out eager caches")
    context.context()._clear_caches()  # pylint: disable=protected-access
  elif not ops.executing_eagerly_outside_functions():
    # Graph (TF1) mode: run the shutdown op in a throwaway session connected
    # to the resolver's master.
    master = cluster_resolver.master()
    cluster_spec = cluster_resolver.cluster_spec()

    session_config = config_pb2.ConfigProto(allow_soft_placement=True)
    if cluster_spec:
      session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

    with ops.Graph().as_default():
      with session_lib.Session(config=session_config, target=master) as sess:
        sess.run(tpu.shutdown_system())
  else:
    # Previously this message named initialize_tpu_system, which pointed
    # users at the wrong API.
    raise RuntimeError("shutdown_tpu_system is not supported within "
                       "tf.functions.")

  logging.info("Finished shutting down TPU system.")
  if tpu_name in _INITIALIZED_TPU_SYSTEMS:
    del _INITIALIZED_TPU_SYSTEMS[tpu_name]