Example #1
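Both examples are excerpted from TensorFlow's internal tpu_strategy_util module, so they rely on internal imports and two module-level globals that this page does not show. The header below is a sketch reconstructed from the names used in the snippets; the exact module paths and initial values are assumptions and may differ between TensorFlow versions.

from tensorflow.core.protobuf import config_pb2
from tensorflow.python.client import session as session_lib
from tensorflow.python.distribute.cluster_resolver import TPUClusterResolver
from tensorflow.python.eager import context
from tensorflow.python.eager import function
from tensorflow.python.framework import device
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import topology
from tensorflow.python.tpu import tpu
from tensorflow.python.util import compat

# Cache of tpu_name -> tf.tpu.Topology for systems initialized in this process.
_INITIALIZED_TPU_SYSTEMS = {}
# Master addresses that resolve to the local machine (no explicit worker job).
_LOCAL_MASTERS = ("", "local")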
def initialize_tpu_system(cluster_resolver=None):
    """Initialize the TPU devices.

    Args:
      cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

    Returns:
      The tf.tpu.Topology object for the topology of the TPU cluster. If called
      inside tf.function, it returns the serialized topology object instead.

    Raises:
      RuntimeError: If running inside a tf.function.
      NotFoundError: If no TPU devices found in eager mode.
    """
    job = None
    if cluster_resolver is None:
        # If no cluster resolver is specified, and running eagerly, execute the init
        # ops in the current device scope.
        if context.executing_eagerly():
            curr_device = device.DeviceSpec.from_string(
                context.context().device_name)
            if curr_device.job is not None:
                job = "{}/replica:0/task:0".format(curr_device.job)

        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        logging.warning(
            "TPU system %s has already been initialized. "
            "Reinitializing the TPU can cause previously created "
            "variables on TPU to be lost.", tpu_name)

    logging.info("Initializing the TPU system: %s", tpu_name)

    # This function looks the way it does for the following non-intuitive
    # reason: tpu.initialize_system creates a dummy op whose sole purpose is to
    # trigger DistributedTPURewritePass. That pass adds the real ops that
    # initialize the TPU system, so we can't simply run tpu.initialize_system
    # eagerly; we need to wrap it in a defun and trigger the rewrite passes on it.
    if tpu_name not in _LOCAL_MASTERS:
        # Explicitly place tpu.initialize_system on the first worker to avoid
        # the "output node matches multiple devices" error.
        job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

    if context.executing_eagerly():

        @function.defun
        def _tpu_init_fn():
            # In TF1, we usually close the chips when compilation fails in order
            # to clear the data in infeed. In TF2 infeed is no longer used, so
            # users can recover from TPU compilation failures more smoothly.
            return tpu.initialize_system(
                job=job, compilation_failure_closes_chips=False)

        # The TPU_SYSTEM device must match the device used in tpu.initialize_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        try:
            with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
                output = _tpu_init_fn()
            context.async_wait()
        except errors.InvalidArgumentError as e:
            raise errors.NotFoundError(
                None, None,
                "TPUs not found in the cluster. Failed in initialization: " +
                str(e))

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access

        serialized_topology = output.numpy()
    elif not ops.executing_eagerly_outside_functions():
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                serialized_topology = sess.run(tpu.initialize_system())
    else:
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            serialized_topology = tpu.initialize_system(
                job=job, compilation_failure_closes_chips=False)
            # If initialize_tpu_system is called inside tf.function, we only return
            # the serialized topology object as the tf.tpu.Topology object has to be
            # constructed in eager mode.
            return serialized_topology

    logging.info("Finished initializing TPU system.")
    tpu_topology = topology.Topology(serialized=serialized_topology)
    _INITIALIZED_TPU_SYSTEMS[tpu_name] = tpu_topology

    return tpu_topology
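For context, the function above is the implementation behind the public tf.tpu.experimental.initialize_tpu_system API. A minimal usage sketch follows, assuming a Colab or Cloud TPU environment where an empty TPU address resolves to the attached TPU worker; adjust the address for your setup.

import tensorflow as tf

# Resolve the TPU cluster; "" works for a co-located Cloud TPU / Colab runtime.
# Replace with your TPU name or a grpc:// address if needed (assumption).
resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)

# Runs the function shown in Example #1 via its public API name and returns a
# tf.tpu.experimental.Topology describing the device mesh.
tpu_topology = tf.tpu.experimental.initialize_tpu_system(resolver)
print("TPU tasks:", tpu_topology.num_tasks,
      "chips per task:", tpu_topology.num_tpus_per_task)

# A TPUStrategy can then be built on the initialized system
# (tf.distribute.experimental.TPUStrategy on older TF 2.x releases).
strategy = tf.distribute.TPUStrategy(resolver)

Calling initialize_tpu_system a second time for the same TPU triggers the warning logged above, since reinitializing the system invalidates variables already placed on the TPU.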
Example #2
def shutdown_tpu_system(cluster_resolver=None):
    """Shuts down the TPU devices.

    This will clear all caches, even those that are maintained through sequential
    calls to tf.tpu.experimental.initialize_tpu_system, such as the compilation
    cache.

    Args:
      cluster_resolver: A tf.distribute.cluster_resolver.TPUClusterResolver,
        which provides information about the TPU cluster.

    Raises:
      RuntimeError: If no TPU devices found for eager execution or if run in a
        tf.function.
    """
    job = None
    if cluster_resolver is None:
        # If no cluster resolver is specified, and running eagerly, execute the init
        # ops in the current device scope.
        if context.executing_eagerly():
            curr_device = device.DeviceSpec.from_string(
                context.context().device_name)
            if curr_device.job is not None:
                job = "{}/replica:0/task:0".format(curr_device.job)

        cluster_resolver = TPUClusterResolver("")
    assert isinstance(cluster_resolver, TPUClusterResolver)

    tpu_name = compat.as_text(cluster_resolver._tpu)  # pylint: disable=protected-access
    if tpu_name not in _INITIALIZED_TPU_SYSTEMS:
        logging.warning(
            "You are shutting down a TPU system %s that has not been "
            "initialized.", tpu_name)

    logging.info("Shutting down the TPU system: %s", tpu_name)

    if context.executing_eagerly():
        # This function looks the way it does for the following non-intuitive
        # reason: tpu.shutdown_system creates a dummy op whose sole purpose is to
        # trigger DistributedTPURewritePass. That pass adds the real ops that shut
        # down the TPU system, so we can't simply run tpu.shutdown_system eagerly;
        # we need to wrap it in a defun and trigger the rewrite passes on it.
        if tpu_name not in _LOCAL_MASTERS:
            # Explicitly place tpu.shutdown_system on the first worker to avoid
            # the "output node matches multiple devices" error.
            job = "{}/replica:0/task:0".format(cluster_resolver.get_job_name())

        @function.defun
        def _tpu_shutdown_fn():
            tpu.shutdown_system(job=job)

        # The TPU_SYSTEM device must match the device used in tpu.shutdown_system
        # exactly, otherwise you can get errors if there are multiple TPU_SYSTEM
        # devices available.
        with ops.device(tpu._tpu_system_device_name(job)):  # pylint: disable=protected-access
            _tpu_shutdown_fn()

        # Clear out the eager context caches since the memory is invalid now.
        logging.info("Clearing out eager caches")
        context.context()._clear_caches()  # pylint: disable=protected-access
    elif not ops.executing_eagerly_outside_functions():
        master = cluster_resolver.master()
        cluster_spec = cluster_resolver.cluster_spec()

        session_config = config_pb2.ConfigProto(allow_soft_placement=True)
        if cluster_spec:
            session_config.cluster_def.CopyFrom(cluster_spec.as_cluster_def())

        with ops.Graph().as_default():
            with session_lib.Session(config=session_config,
                                     target=master) as sess:
                sess.run(tpu.shutdown_system())
    else:
        raise RuntimeError("initialize_tpu_system is not supported within "
                           "tf.functions.")

    logging.info("Finished shutting down TPU system.")
    if tpu_name in _INITIALIZED_TPU_SYSTEMS:
        del _INITIALIZED_TPU_SYSTEMS[tpu_name]
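As a rough companion sketch to Example #2, this is how shutdown is typically paired with re-initialization to reset TPU state within the same process. tf.tpu.experimental.shutdown_tpu_system is the public name for the function above; its availability on older TF 2.x releases is an assumption worth verifying.

import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
tf.tpu.experimental.initialize_tpu_system(resolver)

# ... build a TPUStrategy and run some TPU work here ...

# Shut the system down: this frees TPU memory and clears compilation caches,
# so any variables previously placed on the TPU become invalid.
tf.tpu.experimental.shutdown_tpu_system(resolver)

# It is then legal to initialize again for a fresh run in the same process.
tf.tpu.experimental.initialize_tpu_system(resolver)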