def _fake_mirrored(value, devices):
    """Create a faked Mirrored object for testing.

  All components of the returned Mirrored have the same objects, which is not
  true in reality.
  """
    devices = _get_devices(devices)
    values = []
    for d in devices:
        with ops.device(d):
            values.append(array_ops.identity(value))
    return distribute_utils.regroup(values, wrap_class=value_lib.Mirrored)
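
For orientation, here is a minimal, framework-free sketch of the grouping idea this helper relies on: one copy of the value per device, collected into a single Mirrored-like container. `ToyMirrored` and `toy_regroup` are hypothetical stand-ins, not the TensorFlow implementations.

class ToyMirrored(object):
    # Stand-in for value_lib.Mirrored: one component per device.
    def __init__(self, values):
        self.values = tuple(values)

def toy_regroup(per_device_values, wrap_class=ToyMirrored):
    # Stand-in for distribute_utils.regroup on a flat list of components.
    return wrap_class(per_device_values)

mirrored = toy_regroup(["value_on_device_0", "value_on_device_1"])
assert mirrored.values == ("value_on_device_0", "value_on_device_1")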
Example #2
    def testNamedTuple(self):

        # We include toy implementations of Scaffold and EstimatorSpec to
        # avoid a dependency on Estimator here.

        class Scaffold(object):
            pass

        class EstimatorSpec(
                collections.namedtuple(
                    "EstimatorSpec",
                    ["mode", "loss", "train_op", "scaffold"])):
            def __new__(cls, mode, loss, train_op, scaffold=None):
                return super(EstimatorSpec, cls).__new__(cls,
                                                         mode=mode,
                                                         loss=loss,
                                                         train_op=train_op,
                                                         scaffold=scaffold
                                                         or Scaffold())

        with context.graph_mode(), ops.Graph().as_default():
            created_estimator_specs = []

            for device_id in range(3):
                spec = EstimatorSpec(mode=mode_keys.EstimatorModeKeys.TRAIN,
                                     loss=constant_op.constant(device_id / 2),
                                     train_op=array_ops.identity(
                                         constant_op.constant(device_id)))
                created_estimator_specs.append(spec)

            merged_estimator_spec = distribute_utils.regroup(
                created_estimator_specs)

            self.assertIsInstance(merged_estimator_spec, EstimatorSpec)
            self.assertEqual(mode_keys.EstimatorModeKeys.TRAIN,
                             merged_estimator_spec.mode)
            for device_id in range(3):
                self.assertEqual(created_estimator_specs[device_id].loss,
                                 merged_estimator_spec.loss.values[device_id])
                self.assertEqual(
                    created_estimator_specs[device_id].train_op,
                    merged_estimator_spec.train_op.values[device_id])
                # Scaffold is populated by `EstimatorSpec.__new__`.
                self.assertEqual(
                    created_estimator_specs[device_id].scaffold,
                    merged_estimator_spec.scaffold.values[device_id])
                self.assertIsInstance(
                    created_estimator_specs[device_id].scaffold, Scaffold)
                # Also test that we can undo the merge using select_replica()
                self.assertEqual(
                    created_estimator_specs[device_id],
                    distribute_utils.select_replica(device_id,
                                                    merged_estimator_spec))
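
The assertions above rely on `regroup` rebuilding the namedtuple type around per-replica values field by field, and on `select_replica` inverting that merge. Below is a simplified, pure-Python sketch of that round trip; `toy_regroup_namedtuples` and `toy_select_replica` are hypothetical helpers and use plain tuples where the real code would use PerReplica/Mirrored containers.

import collections

Spec = collections.namedtuple("Spec", ["mode", "loss"])

def toy_regroup_namedtuples(specs):
    # Merge field-wise: each field becomes the tuple of per-replica values.
    cls = type(specs[0])
    return cls(*[tuple(getattr(s, f) for s in specs) for f in cls._fields])

def toy_select_replica(replica_id, merged):
    # Undo the merge for a single replica.
    cls = type(merged)
    return cls(*[getattr(merged, f)[replica_id] for f in cls._fields])

specs = [Spec(mode="train", loss=0.0), Spec(mode="train", loss=0.5)]
merged = toy_regroup_namedtuples(specs)
assert merged.loss == (0.0, 0.5)
assert toy_select_replica(1, merged) == specs[1]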
Example #3
def simple_broadcast(value, destinations, always_mirrored=False):
  """Broadcast `value` to `destinations` using simple copies."""
  devices = get_devices_from(destinations)
  if len(devices) == 1 and not always_mirrored:
    return cross_device_utils.copy_tensor_or_indexed_slices_to_device(
        value, devices[0])
  else:
    value_updates = []
    for d in devices:
      value_updates.append(
          cross_device_utils.copy_tensor_or_indexed_slices_to_device(value, d))
    return distribute_utils.regroup(value_updates,
                                    wrap_class=value_lib.Mirrored)
Example #4
  def batch_reduce(self,
                   reduce_op,
                   value_destination_pairs,
                   experimental_hints=None):
    """Reduce PerReplica objects in a batch.

    Reduce the first element of each pair in `value_destination_pairs` to its
    corresponding second element, which indicates the destinations.

    This can be faster than multiple individual `reduce`s because we can
    fuse several tensors into one or multiple packs before reduction.

    Args:
      reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how the
        `per_replica_value` will be reduced.
      value_destination_pairs: A list or a tuple of PerReplica objects (or
        tensors with device set if there is one device) and destinations.
      experimental_hints: A `tf.distribute.experimental.CollectiveHints`. Hints
        to perform collective operations.

    Returns:
      a list of Mirrored objects.

    Raises:
      ValueError: if `value_destination_pairs` is not an iterable of
        tuples of PerReplica objects and destinations.
    """
    # TODO(yuefengz): if destinations are different, split into several
    # `_batch_reduce` invocations.
    if not _validate_value_destination_pairs(value_destination_pairs):
      # If the first element of each pair is a tensor, we try to turn it into a
      # PerReplica object.
      value_destination_pairs = _normalize_value_destination_pairs(
          value_destination_pairs)

    for _, d in value_destination_pairs:
      validate_destinations(d)

    # Shortcut if all PerReplica objects contain only one value.
    if self._num_between_graph_workers == 1 and _all_devices_match(
        value_destination_pairs) and len(
            value_destination_pairs[0][0].values) == 1:
      return [
          distribute_utils.regroup(v.values, wrap_class=value_lib.Mirrored)
          for v, _ in value_destination_pairs
      ]

    if experimental_hints is None:
      experimental_hints = collective_util.Hints()
    return self.batch_reduce_implementation(reduce_op, value_destination_pairs,
                                            experimental_hints)
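
Semantically, `batch_reduce` reduces each value to its paired destination; the batching only allows the implementation to fuse several tensors into packs first, as the docstring notes. A toy sketch of the per-pair semantics (hypothetical names; the real method delegates to `batch_reduce_implementation`):

def toy_batch_reduce(reduce_fn, value_destination_pairs):
    # Reduce each value to its paired destination; the real implementation may
    # fuse the tensors into packs before performing the reductions.
    return [reduce_fn(values, destination)
            for values, destination in value_destination_pairs]

sums = toy_batch_reduce(lambda values, dst: sum(values),
                        [((1.0, 2.0), "GPU:0"), ((3.0, 4.0), "GPU:0")])
assert sums == [3.0, 7.0]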
Example #5
    def _call_for_each_replica(self, fn, args, kwargs):
        # For now, `fn` must be an @tf.function.
        # TODO(josh11b): Relax this restriction?  Main problem is if
        # (a) executing eagerly, (b) `fn` not @tf.function, and
        # (c) executed frequently.
        assert isinstance(fn, def_function.Function)

        if _outside_run_graph() is not None:
            # Nested case, should just use outer function's context for things like
            # the current replica index.
            # TODO(josh11b): Test this case!
            with MirroredFunctionReplicaContext(self._container_strategy()):
                results = fn(*nest.map_structure(_unwrap_tensors, args),
                             **nest.map_structure(_unwrap_tensors, kwargs))
                return nest.map_structure(_wrap_tensors, results)

        _replica_index.graph_outside_run = ops.get_default_graph()
        return_values = []

        try:
            with MirroredFunctionReplicaContext(self._container_strategy()):
                for index, device in enumerate(self._devices):
                    _replica_index.current = index
                    with ops.device(device):
                        if context.executing_eagerly():
                            # NOTE: These functions need to execute concurrently if they
                            # use a collective op. This is a particular concern with eager
                            # execution.
                            with context.execution_mode(context.ASYNC):
                                return_values.append(
                                    fn(
                                        *distribute_utils.select_replica(
                                            index, args),
                                        **distribute_utils.select_replica(
                                            index, kwargs)))
                        else:
                            return_values.append(
                                fn(
                                    *distribute_utils.select_replica(
                                        index, args),
                                    **distribute_utils.select_replica(
                                        index, kwargs)))
        finally:
            _replica_index.graph_outside_run = None
            _replica_index.current = None

        return distribute_utils.regroup(return_values)
Example #6
def _make_per_replica(values, devices, regroup=False):
  devices = _get_devices(devices)
  assert len(values) == len(devices)

  # We simulate the result of regroup called on PerReplica which strips the
  # PerReplica wrapper if it has only one value.
  if len(values) == 1 and regroup:
    with ops.device(devices[0]):
      placed_v = array_ops.identity(values[0])
    return placed_v

  index = []
  for d, v in zip(devices, values):
    with ops.device(d):
      placed_v = array_ops.identity(v)
    index.append(placed_v)
  return distribute_utils.regroup(index)
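
The comment above points out that `regroup` strips the PerReplica wrapper when there is only a single component. A toy sketch of that special case, with hypothetical stand-ins rather than the TensorFlow classes:

class ToyPerReplica(object):
    def __init__(self, values):
        self.values = tuple(values)

def toy_regroup_single_aware(components):
    # A single component comes back unwrapped; multiple components are wrapped.
    return components[0] if len(components) == 1 else ToyPerReplica(components)

assert toy_regroup_single_aware(["only"]) == "only"
assert toy_regroup_single_aware(["a", "b"]).values == ("a", "b")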
Example #7
    def reduce(self,
               reduce_op,
               per_replica_value,
               destinations,
               experimental_hints=None):
        """Reduce `per_replica_value` to `destinations`.

        It runs the reduction operation defined by `reduce_op` and puts the
        result on `destinations`.

        Args:
          reduce_op: An instance of `tf.distribute.ReduceOp` that indicates how
            `per_replica_value` will be reduced.
          per_replica_value: A `tf.distribute.DistributedValues` object or a
            tensor with device set.
          destinations: the reduction destinations.
          experimental_hints: A `tf.distribute.experimental.CollectiveHints`.
            Hints to perform collective operations.

        Returns:
          a Mirrored object.

        Raises:
          ValueError: if `per_replica_value` can't be converted to a PerReplica
            object, or if `destinations` aren't strings, Variables or
            DistributedValues.
        """
        if not isinstance(per_replica_value, value_lib.DistributedValues):
            per_replica_value = _make_tensor_into_per_replica(
                per_replica_value)

        validate_destinations(destinations)

        # Shortcut if `per_replica_value` only contains one value.
        if self._num_between_graph_workers == 1 and len(
                per_replica_value.values) == 1 and _devices_match(
                    per_replica_value, destinations):
            with ops.device(per_replica_value.values[0].device):
                v = array_ops.identity(per_replica_value.values[0])
            return distribute_utils.regroup((v, ),
                                            wrap_class=value_lib.Mirrored)

        if experimental_hints is None:
            experimental_hints = collective_util.Hints()
        return self.reduce_implementation(reduce_op, per_replica_value,
                                          destinations, experimental_hints)
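
As a quick illustration of what the `reduce_op` argument means for the values involved, here is a toy, framework-free reduction over per-replica components; the real reduction of course runs on devices and returns a Mirrored object:

def toy_reduce(reduce_op, replica_values):
    # Illustrates the reduce_op semantics only; real reductions run on devices.
    total = sum(replica_values)
    return total / len(replica_values) if reduce_op == "MEAN" else total

assert toy_reduce("SUM", [1.0, 2.0, 3.0]) == 6.0
assert toy_reduce("MEAN", [1.0, 2.0, 3.0]) == 2.0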
Example #8
    def testSameId(self):
        foo = object()
        result = distribute_utils.regroup((("a", foo), ("b", foo)))
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 2)
        self._is_per_replica(result[0], ["a", "b"])
        self.assertIs(foo, result[1])

        # Test select_replica(), should undo the merge done by regroup().
        result_0 = distribute_utils.select_replica(0, result)
        self.assertIsInstance(result_0, tuple)
        self.assertLen(result_0, 2)
        self.assertEqual("a", result_0[0])
        self.assertIs(foo, result_0[1])
        result_1 = distribute_utils.select_replica(1, result)
        self.assertIsInstance(result_1, tuple)
        self.assertLen(result_1, 2)
        self.assertEqual("b", result_1[0])
        self.assertIs(foo, result_1[1])
Example #9
    def testRegroupCollectionsMapping(self):
        class CollectionsMappingBasedClass(collections.Mapping):
            """Class inherited from collections.Mapping."""
            def __init__(self, *args, **kwargs):
                self._d = dict(*args, **kwargs)

            def __getitem__(self, key):
                return self._d.__getitem__(key)

            def __iter__(self):
                return iter(self._d)

            def __len__(self):
                return len(self._d)

        result = distribute_utils.regroup(
            (CollectionsMappingBasedClass(a="a1", b="b1"),
             CollectionsMappingBasedClass(a="a2", b="b2")))
        self.assertIsInstance(result, CollectionsMappingBasedClass)
        self._is_per_replica(result["a"], ["a1", "a2"])
        self._is_per_replica(result["b"], ["b1", "b2"])
Example #10
    def _do_batch_all_reduce_sparse(self, reduce_op, per_replica_values,
                                    experimental_hints):
        """All-reduce IndexedSlices across all workers in a batch."""

        logging.log_first_n(
            logging.INFO, "Collective batch_all_reduce for IndexedSlices: "
            "%d all-reduces, group_size = %d" %
            (len(per_replica_values), self._group_size), 10)

        # Pass self._communication to the runtime as a communication hint.
        communication_hint = self._communication.value
        # For now, we use NCCL only when batch_size > 1.
        # TODO(b/132575814): switch to NCCL for all collectives when communication
        # is NCCL.
        if self._communication == CollectiveCommunication.NCCL and len(
                per_replica_values) == 1:
            communication_hint = CollectiveCommunication.AUTO.value

        gathered_values = []
        with self._lock, ops.name_scope("allreduce"):
            for per_replica in per_replica_values:
                gathered_values.append(
                    cross_device_utils.build_collective_gather_indexed_slices(
                        per_replica.values,
                        self._devices,
                        self._group_size,
                        self._collective_keys,
                        communication_hint,
                        timeout=experimental_hints.timeout_seconds))

        mirrored = []
        for value in gathered_values:
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # Assume each worker has the same number of replicas.
                for i, v in enumerate(value):
                    with ops.device(v.device):
                        value[i].values = value[i].values / self._group_size
            mirrored.append(
                distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
        return mirrored
Example #11
def _ungroup_and_make_mirrored(grouped_reduced,
                               destinations,
                               reduce_op,
                               num_between_graph_workers=1):
    """Ungroup results from all-reduce and make Mirrored objects.

    Each all-reduce result will be divided by the number of destinations before
    Mirrored objects are created if reduce_op is "mean".

    Args:
      grouped_reduced: a list of lists, each sublist has components for each
        device, paired with a None. It is the result from
        cross_device_utils.aggregate_gradients_using*.
      destinations: a value to colocate the result with.
      reduce_op: Indicates how values will be aggregated. Accepted values
        are `tf.distribute.ReduceOp.SUM`, `tf.distribute.ReduceOp.MEAN`.
      num_between_graph_workers: number of workers in the between-graph
        replication.

    Returns:
      a list of Mirrored objects.
    """
    num_replicas = len(
        get_devices_from(destinations)) * num_between_graph_workers
    index = [[] for _ in range(len(grouped_reduced[0]))]
    for per_replica_reduced in grouped_reduced:
        for i, (v, _) in enumerate(per_replica_reduced):
            if reduce_op == reduce_util.ReduceOp.MEAN:
                with ops.device(v.device):
                    index[i].append(v / num_replicas)
            else:
                index[i].append(v)
    return [
        distribute_utils.regroup(v, wrap_class=value_lib.Mirrored)
        for v in index
    ]
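
Stripped of device placement and Mirrored wrapping, the ungroup step is a transpose of `grouped_reduced` plus an optional division by the replica count. A framework-free sketch with a hypothetical `toy_ungroup`:

def toy_ungroup(grouped_reduced, num_replicas, mean=True):
    # Transpose grouped results into one list per position, dividing by the
    # replica count when the reduce op is "mean".
    out = [[] for _ in range(len(grouped_reduced[0]))]
    for per_replica_reduced in grouped_reduced:
        for i, (v, _) in enumerate(per_replica_reduced):
            out[i].append(v / num_replicas if mean else v)
    return out  # each out[i] would then be wrapped into a Mirrored object

grouped = [[(4.0, None), (8.0, None)],
           [(4.0, None), (8.0, None)]]
assert toy_ungroup(grouped, num_replicas=2) == [[2.0, 2.0], [4.0, 4.0]]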
Example #12
def _call_for_each_replica(distribution, fn, args, kwargs):
    """Run `fn` in separate threads, once per replica/worker device.

    Args:
      distribution: the DistributionStrategy object.
      fn: function to run (will be run once per replica, each in its own
        thread).
      args: positional arguments for `fn`.
      kwargs: keyword arguments for `fn`.

    Returns:
      Merged return value of `fn` across all replicas.

    Raises:
      RuntimeError: If `fn()` calls `get_replica_context().merge_call()` a
          different number of times on different replicas.
    """
    # TODO(josh11b): Add this option once we add synchronization to variable
    # creation. Until then, this is pretty unsafe to use.
    run_concurrently = False
    if not context.executing_eagerly():
        # Needed for per-thread device, etc. contexts in graph mode.
        ops.get_default_graph().switch_to_thread_local()

    coord = coordinator.Coordinator(
        clean_stop_exception_types=(_RequestedStop, ))

    shared_variable_store = {}
    devices = distribution.extended.worker_devices

    thread_local_callables = _get_thread_local_configuration_callable()

    # TODO(isaprykin): Create these threads once instead of during every call.
    threads = []
    for index in range(len(devices)):
        variable_creator_fn = shared_variable_creator.make_fn(
            shared_variable_store, index)
        t = _MirroredReplicaThread(
            distribution, coord, index, devices, variable_creator_fn, fn,
            distribute_utils.caching_scope_local,
            distribute_utils.select_replica(index, args),
            distribute_utils.select_replica(index,
                                            kwargs), thread_local_callables)
        threads.append(t)

    for t in threads:
        t.start()

    # When `fn` starts, the `should_run` event is set on the
    # _MirroredReplicaThread (`MRT`) threads. The execution waits until
    # `MRT.has_paused` is set, which indicates that either `fn` is
    # complete or a `get_replica_context().merge_call()` is called.  If `fn` is
    # complete, then `MRT.done` is set to True.  Otherwise, arguments
    # of `get_replica_context().merge_call` from all paused threads are grouped
    # and the `merge_fn` is performed.  Results of the
    # `get_replica_context().merge_call` are then set to `MRT.merge_result`.
    # Each such `get_replica_context().merge_call` call returns the
    # `MRT.merge_result` for that thread when `MRT.should_run` event
    # is reset again. Execution of `fn` resumes.

    try:
        with coord.stop_on_exception():
            all_done = False
            while not all_done and not coord.should_stop():
                done = []
                if run_concurrently:
                    for t in threads:
                        t.should_run.set()
                    for t in threads:
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                else:
                    for t in threads:
                        t.should_run.set()
                        t.has_paused.wait()
                        t.has_paused.clear()
                        if coord.should_stop():
                            return None
                        done.append(t.done)
                if coord.should_stop():
                    return None
                all_done = all(done)
                if not all_done:
                    if any(done):
                        raise RuntimeError(
                            "Some replicas made a different number of "
                            "replica_context().merge_call() calls.")
                    # get_replica_context().merge_call() case
                    merge_args = distribute_utils.regroup(
                        tuple(t.merge_args for t in threads))
                    merge_kwargs = distribute_utils.regroup(
                        tuple(t.merge_kwargs for t in threads))
                    # We capture the name_scope of the MRT when we call merge_fn
                    # to ensure that if we have opened a name scope in the MRT,
                    # it will be respected when executing the merge function. We only
                    # capture the name_scope from the first MRT and assume it is
                    # the same for all other MRTs.
                    mtt_captured_name_scope = threads[0].captured_name_scope
                    mtt_captured_var_scope = threads[0].captured_var_scope
                    # Capture and merge the control dependencies from all the threads.
                    mtt_captured_control_deps = set()
                    for t in threads:
                        mtt_captured_control_deps.update(
                            t.captured_control_deps)
                    with ops.name_scope(mtt_captured_name_scope),\
                        ops.control_dependencies(mtt_captured_control_deps), \
                        variable_scope.variable_scope(mtt_captured_var_scope):
                        merge_result = threads[0].merge_fn(
                            distribution, *merge_args, **merge_kwargs)
                    for r, t in enumerate(threads):
                        t.merge_result = distribute_utils.select_replica(
                            r, merge_result)
    finally:
        for t in threads:
            t.should_run.set()
        coord.join(threads)

    return distribute_utils.regroup(tuple(t.main_result for t in threads))
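
The comment block above describes the `should_run` / `has_paused` handshake between the driver and the `_MirroredReplicaThread`s. Below is a minimal, single-thread sketch of that lock-step pattern using plain `threading.Event`s; `StepThread` is a hypothetical illustration, not the TensorFlow class.

import threading

class StepThread(threading.Thread):
    """One worker driven in lock-step via should_run / has_paused events."""

    def __init__(self, num_steps):
        super(StepThread, self).__init__()
        self.should_run = threading.Event()
        self.has_paused = threading.Event()
        self.num_steps = num_steps
        self.done = False
        self.results = []

    def run(self):
        for step in range(self.num_steps):
            self.should_run.wait()
            self.should_run.clear()
            self.results.append(step)      # stand-in for running part of `fn`
            if step == self.num_steps - 1:
                self.done = True           # mark completion before pausing
            self.has_paused.set()          # hand control back to the driver

t = StepThread(num_steps=3)
t.start()
while True:
    t.should_run.set()                     # let the thread run to its next pause
    t.has_paused.wait()
    t.has_paused.clear()
    if t.done:
        break
t.join()
assert t.results == [0, 1, 2]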
Example #13
    def _experimental_run_steps_on_iterator(self,
                                            fn,
                                            iterator,
                                            iterations,
                                            initial_loop_values=None):
        if initial_loop_values is None:
            initial_loop_values = {}
        initial_loop_values = nest.flatten(initial_loop_values)

        ctx = input_lib.MultiStepContext()

        def body(i, *args):
            """A wrapper around `fn` to create the while loop body."""
            del args
            fn_result = fn(ctx, iterator.get_next())
            for (name, output) in ctx.last_step_outputs.items():
                # Convert all outputs to tensors, potentially from `DistributedValues`.
                ctx.last_step_outputs[name] = self._local_results(output)
            flat_last_step_outputs = nest.flatten(ctx.last_step_outputs)
            with ops.control_dependencies([fn_result]):
                return [i + 1] + flat_last_step_outputs

        # We capture the control_flow_context at this point, before we run `fn`
        # inside a while_loop. This is useful in cases where we might need to exit
        # these contexts and get back to the outer context to do some things,
        # e.g. to create an op which should be evaluated only once at the end of
        # loop on the host. One such usage is in creating metrics' value op.
        self._outer_control_flow_context = (
            ops.get_default_graph()._get_control_flow_context())  # pylint: disable=protected-access

        cond = lambda i, *args: i < iterations
        i = constant_op.constant(0)
        loop_result = control_flow_ops.while_loop(cond,
                                                  body,
                                                  [i] + initial_loop_values,
                                                  name="",
                                                  parallel_iterations=1,
                                                  back_prop=False,
                                                  swap_memory=False,
                                                  return_same_structure=True)
        del self._outer_control_flow_context

        ctx.run_op = control_flow_ops.group(loop_result)

        # Convert the last_step_outputs from a list to the original dict structure
        # of last_step_outputs.
        last_step_tensor_outputs = loop_result[1:]
        last_step_tensor_outputs_dict = nest.pack_sequence_as(
            ctx.last_step_outputs, last_step_tensor_outputs)

        for name, reduce_op in ctx._last_step_outputs_reduce_ops.items():  # pylint: disable=protected-access
            output = last_step_tensor_outputs_dict[name]
            # For outputs that have already been reduced, wrap them in a Mirrored
            # container, else in a PerReplica container.
            if reduce_op is None:
                last_step_tensor_outputs_dict[name] = distribute_utils.regroup(
                    output)
            else:
                assert len(output) == 1
                last_step_tensor_outputs_dict[name] = output[0]

        ctx._set_last_step_outputs(last_step_tensor_outputs_dict)  # pylint: disable=protected-access
        return ctx
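
The loop body flattens `last_step_outputs` before entering the `while_loop` and packs the flat results back into the original structure afterwards. A toy version of that flatten / pack round trip for a flat dict (dict entries handled in sorted-key order), with hypothetical helpers standing in for `nest.flatten` and `nest.pack_sequence_as`:

def toy_flatten(structure):
    # Flat list of leaf values, dict entries in sorted-key order.
    return [structure[k] for k in sorted(structure)]

def toy_pack_sequence_as(structure, flat):
    # Rebuild the dict structure from a flat list produced by toy_flatten.
    return {k: v for k, v in zip(sorted(structure), flat)}

outputs = {"loss": 0.5, "accuracy": 0.75}
flat = toy_flatten(outputs)                      # [0.75, 0.5]
doubled = [2 * v for v in flat]
assert toy_pack_sequence_as(outputs, doubled) == {"accuracy": 1.5, "loss": 1.0}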
Example #14
    def testWrapAListOfTwoTuples(self):
        result = distribute_utils.regroup([("1", "2"), ("3", "4")])
        self.assertIsInstance(result, tuple)
        self.assertLen(result, 2)
        self._is_per_replica(result[0], ("1", "3"), values.PerReplica)
        self._is_per_replica(result[1], ("2", "4"), values.PerReplica)
Example #15
def _make_mirrored_indexed_slices(devices, values, indices, dense_shape):
  values = [_make_indexed_slices(values, indices, dense_shape, d)
            for d in devices]
  return distribute_utils.regroup(
      values,
      wrap_class=value_lib.Mirrored)
Example #16
  def _call_for_each_replica(self, fn, args, kwargs):
    with distribute_lib.ReplicaContext(
        self._container_strategy(),
        replica_id_in_sync_group=constant_op.constant(0, dtypes.int32)):
      # TODO(rchao): Support multi-replica per worker or sync-group.
      return distribute_utils.regroup((fn(*args, **kwargs),))
Example #17
        def tpu_function(args, kwargs):
            """TF Function used to replicate the user computation."""
            if kwargs is None:
                kwargs = {}

            # Remove None at the end of args as they are not replicatable
            # If there are None in the middle we can't do anything about it
            # so let those cases fail.
            # For example when Keras model predict is used they pass the targets as
            # None. We want to handle it here so all client libraries don't have to
            # do this as other strategies can handle None values better.
            while args and args[-1] is None:
                args = args[:-1]

            # Used to re-structure flattened output tensors from `tpu.replicate()`
            # into a structured format.
            result = [[]]

            def replicated_fn(replica_id, replica_args, replica_kwargs):
                """Wraps user function to provide replica ID and `Tensor` inputs."""
                with _TPUReplicaContext(strategy,
                                        replica_id_in_sync_group=replica_id):
                    result[0] = fn(*replica_args, **replica_kwargs)
                return result[0]

            replicate_inputs = []  # By replica.
            for i in range(strategy.num_replicas_in_sync):
                replicate_inputs.append([
                    constant_op.constant(i, dtype=dtypes.int32),
                    distribute_utils.select_replica(i, args),
                    distribute_utils.select_replica(i, kwargs)
                ])

            # Construct and pass `maximum_shapes` so that we could support dynamic
            # shapes using dynamic padder.
            if options.experimental_enable_dynamic_batch_size and replicate_inputs:
                maximum_shapes = []
                flattened_list = nest.flatten(replicate_inputs[0])
                for input_tensor in flattened_list:
                    if tensor_util.is_tensor(input_tensor):
                        rank = input_tensor.shape.rank
                    else:
                        rank = np.ndim(input_tensor)
                    maximum_shape = tensor_shape.TensorShape([None] * rank)
                    maximum_shapes.append(maximum_shape)
                maximum_shapes = nest.pack_sequence_as(replicate_inputs[0],
                                                       maximum_shapes)
            else:
                maximum_shapes = None

            if options.experimental_bucketizing_dynamic_shape:
                padding_spec = tpu.PaddingSpec.POWER_OF_TWO
            else:
                padding_spec = None

            with strategy.scope():
                replicate_outputs = tpu.replicate(
                    replicated_fn,
                    replicate_inputs,
                    device_assignment=self._device_assignment,
                    maximum_shapes=maximum_shapes,
                    padding_spec=padding_spec,
                    xla_options=tpu.XLAOptions(
                        use_spmd_for_xla_partitioning=False))

            # Remove all no-ops that may have been added during `tpu.replicate()`.
            if isinstance(result[0], list):
                result[0] = [
                    output for output in result[0]
                    if not isinstance(output, ops.Operation)
                ]

            # Workaround for `tpu.replicate` behaviour when single `Tensor` returned.
            if result[0] is None or isinstance(result[0], ops.Operation):
                replicate_outputs = [None] * len(replicate_outputs)
            else:
                replicate_outputs = [
                    nest.pack_sequence_as(result[0],
                                          nest.flatten(replica_output))
                    for replica_output in replicate_outputs
                ]
            return distribute_utils.regroup(replicate_outputs)
Example #18
    def testOneDevice(self):
        result = distribute_utils.regroup((_nested_value("1"), ))
        # On one device regroup() and select_replica() are basically identity.
        self.assertEqual(_nested_value("1"), result)
        self.assertEqual(_nested_value("1"),
                         distribute_utils.select_replica(0, result))
Example #19
    def _do_batch_all_reduce_dense(self, reduce_op, per_replica_values,
                                   experimental_hints):
        """All-reduce across all workers in a batch."""

        batch_size = len(per_replica_values)
        # Pass self._communication to the runtime as a communication hint.
        communication = self._communication.value
        # For now, we use NCCL only when batch_size > 1.
        # TODO(b/132575814): switch to NCCL for all collectives when communication
        # is NCCL.
        if self._communication == CollectiveCommunication.NCCL and batch_size == 1:
            communication = CollectiveCommunication.AUTO.value

        # Reverse the lists so that there's a better chance that values follow
        # the order in which they are calculated (e.g. when they're gradients), so
        # as to overlap calculation with communication. However, this may not be
        # optimal for cases like gradients of complicated non-sequential models.
        #
        # Note that we reverse the list before packing so that the first pack won't
        # be too small, since it's more likely for first few packs to have long
        # queuing time due to concurrent intense computation.
        #
        # TODO(b/147393503): explore solutions for optimal ordering.
        packs = cross_device_utils.pack_by_size(
            list(reversed(per_replica_values)),
            experimental_hints.bytes_per_pack)

        if batch_size > 1:
            logging.info(
                "Collective batch_all_reduce: %d all-reduces, num_devices = %d, "
                "group_size = %d, communication_hint = %s, num_packs = %d",
                batch_size, len(self._devices), self._group_size,
                communication, len(packs))
        else:
            logging.log_first_n(
                logging.INFO, "Collective batch_all_reduce: %d all-reduces, "
                "num_devices = %d, group_size = %d, communication_hint = %s, "
                "num_packs = %d" %
                (batch_size, len(self._devices), self._group_size,
                 communication, len(packs)), 10)

        reduced_values = []
        with self._lock:
            for pack in packs:
                # By placing all CollectiveReduce ops in a pack under a single name scope,
                # we ensure they will be picked up by the `ScopedAllocator` grappler
                # optimizer and packed into a single all-reduce.
                with ops.name_scope("allreduce"):
                    for per_replica in pack:
                        # Add control dependencies per device from the last gradients to the
                        # current set, in order to serialize NCCL launches.
                        if (communication == CollectiveCommunication.NCCL.value
                                and reduced_values):
                            control_inputs = list(reduced_values[-1])
                        else:
                            control_inputs = None
                        reduced_values.append(
                            cross_device_utils.build_collective_reduce(
                                per_replica.values,
                                self._devices,
                                self._group_size,
                                self._collective_keys,
                                "Add",
                                "Id",
                                communication,
                                control_inputs,
                                executors=self._executors,
                                timeout=experimental_hints.timeout_seconds))

        for e in self._executors:
            e.wait()

        mirrored = []
        # Reverse the order of reduced values to recover the order in the input.
        for value in reversed(reduced_values):
            if reduce_op == reduce_util.ReduceOp.MEAN:
                for i, v in enumerate(value):
                    with ops.device(v.device):
                        value[i] = v / self._group_size
            mirrored.append(
                distribute_utils.regroup(value, wrap_class=value_lib.Mirrored))
        return mirrored
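
The comment above explains why the per-replica values are reversed before being packed by a byte budget. A toy greedy packing sketch follows; `toy_pack_by_size` is a hypothetical stand-in for `cross_device_utils.pack_by_size`.

def toy_pack_by_size(values, bytes_per_pack):
    # Greedily fill packs up to the byte budget; values are (name, nbytes).
    packs, current, current_bytes = [], [], 0
    for name, nbytes in values:
        if current and current_bytes + nbytes > bytes_per_pack:
            packs.append(current)
            current, current_bytes = [], 0
        current.append((name, nbytes))
        current_bytes += nbytes
    if current:
        packs.append(current)
    return packs

# Gradients listed in reversed order, as in the code above.
grads = [("g3", 4), ("g2", 4), ("g1", 8)]
assert toy_pack_by_size(grads, bytes_per_pack=8) == [[("g3", 4), ("g2", 4)],
                                                     [("g1", 8)]]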