Example #1
  def allreduce(self, x, mesh_axes, reduction_fn_string):
    """Grouped allreduce, (summed across the given dimensions).

    Args:
      x: a LaidOutTensor
      mesh_axes: a list of integers
      reduction_fn_string: "SUM"
    Returns:
      a LaidOutTensor
    Raises:
      ValueError: if the reduction is not yet implemented.
    """
    if not mesh_axes:
      return x
    x = x.to_laid_out_tensor()
    if reduction_fn_string == "SUM":
      group_assignment = self._create_group_assignment(mesh_axes)
      return self.LaidOutTensor(
          [tpu_ops.cross_replica_sum(x.one_slice, group_assignment)])
    else:
      for axis in mesh_axes:
        x = self.allconcat(x, axis, 0, stack=True)
        x = self.LaidOutTensor(
            [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
      return x
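Example #1's group_assignment comes from self._create_group_assignment(mesh_axes), which is not shown here; the op expects a 2-D list with one row of replica ids per independent reduction group. As a rough, self-contained sketch (a hand-rolled helper of mine, not the mesh_tensorflow implementation), processors that differ only along the reduced mesh axes can be grouped like this:

import itertools

def build_group_assignment(mesh_shape, mesh_axes):
    # Hypothetical illustration: replicas whose coordinates agree on every *kept*
    # axis share a group, so the sum runs independently along the reduced axes.
    kept_axes = [a for a in range(len(mesh_shape)) if a not in mesh_axes]
    groups = {}
    for pnum, coords in enumerate(itertools.product(*[range(d) for d in mesh_shape])):
        key = tuple(coords[a] for a in kept_axes)
        groups.setdefault(key, []).append(pnum)
    return list(groups.values())

# build_group_assignment([2, 4], [1]) -> [[0, 1, 2, 3], [4, 5, 6, 7]]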
Example #2
    def mtf_model_fn(self, features, mesh):
        hparams = self._hparams
        hparams.batch_size = 10
        hparams.io_size = 4
        hparams.hidden_size = 2
        tf_x = tf.matmul(
            tf.reshape(tf.lin_space(0., 1.0, hparams.batch_size),
                       [hparams.batch_size, 1]),
            tf.reshape(tf.lin_space(0., 1.0, hparams.io_size),
                       [1, hparams.io_size]))
        # tf_x = tf.random_uniform([hparams.batch_size, hparams.io_size])

        hidden_1_variable = tf.get_variable(
            "a",
            shape=[hparams.io_size, hparams.hidden_size],
            initializer=tf.random_normal_initializer())
        hidden_2_variable = tf.get_variable(
            "b",
            shape=[hparams.hidden_size, hparams.io_size],
            initializer=tf.random_normal_initializer())

        hidden_layer_1 = tf.matmul(tf_x, hidden_1_variable)
        hidden_layer_2 = tf.matmul(hidden_layer_1, hidden_2_variable)
        hidden_layer_2 = tpu_ops.cross_replica_sum(hidden_layer_2)
        loss = tf.reduce_mean(tf.square(hidden_layer_2 - tf_x))
        return None, loss
Example #3
  def _reduce(self, aggregation, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if aggregation == vs.VariableAggregation.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self.num_towers)
      elif aggregation != vs.VariableAggregation.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    # Validate that the destination is same as the host device
    # Note we don't do this when in replicate context as the reduction is
    # performed on the TPU device itself.
    devices = cross_tower_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self.get_host_cpu_device(0))
    else:
      raise ValueError('Multiple devices are not supported for TPUStrategy')

    if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
      return value[0]
    output = math_ops.add_n(value)
    if aggregation == vs.VariableAggregation.MEAN:
      return output * (1. / len(value))
    return output
Example #4
    def _reduce(self, aggregation, value, destinations):
        graph = ops.get_default_graph()
        cf_context = graph._get_control_flow_context()  # pylint: disable=protected-access
        # If we're inside the ReplicateContext, reduction should be done using
        # CrossReplicaSum while outside we can directly use an add_n op.
        while cf_context:
            if isinstance(cf_context, tpu.TPUReplicateContext):
                if aggregation == vs.VariableAggregation.MEAN:
                    # TODO(jhseu):  Revisit once we support model-parallelism.
                    value *= (1. / self.num_towers)
                return tpu_ops.cross_replica_sum(value)
            cf_context = cf_context.outer_context

        # Validate that the destination is same as the host device
        # Note we don't do this when in replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(self._host)
        else:
            raise ValueError(
                'Multiple devices are not supported for TPUStrategy')

        output = math_ops.add_n(value)
        if aggregation == vs.VariableAggregation.MEAN:
            return output * (1. / len(value))
        return output
Example #5
  def _reduce(self, aggregation, value, destinations):
    graph = ops.get_default_graph()
    cf_context = graph._get_control_flow_context()  # pylint: disable=protected-access
    # If we're inside the ReplicateContext, reduction should be done using
    # CrossReplicaSum while outside we can directly use an add_n op.
    while cf_context:
      if isinstance(cf_context, tpu.TPUReplicateContext):
        if aggregation == vs.VariableAggregation.MEAN:
          # TODO(jhseu):  Revisit once we support model-parallelism.
          value *= (1. / self.num_towers)
        return tpu_ops.cross_replica_sum(value)
      cf_context = cf_context.outer_context

    # Validate that the destination is same as the host device
    # Note we don't do this when in replicate context as the reduction is
    # performed on the TPU device itself.
    devices = cross_tower_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host)
    else:
      raise ValueError('Multiple devices are not supported for TPUStrategy')

    output = math_ops.add_n(value)
    if aggregation == vs.VariableAggregation.MEAN:
      return output * (1. / len(value))
    return output
Example #6
    def _reduce_to(self, reduce_op, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, self._device_map, value, destinations)

        # Validate that the destination is same as the host device
        # Note we don't do this when in replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_device_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(self._host_device)
        else:
            raise ValueError(
                "Multiple devices are not supported for TPUStrategy")

        output = math_ops.add_n(value)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            return output * (1. / len(value))
        return output
Example #7
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    # Validate that the destination is same as the host device
    # Note we don't do this when in replicate context as the reduction is
    # performed on the TPU device itself.
    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host_device)
    else:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    output = math_ops.add_n(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      return output * (1. / len(value))
    return output
Example #8
  def apply_gradients(self, grads_and_vars, global_step=None, name=None):
    """Apply gradients to variables.

    Calls tpu_ops.cross_replica_sum() to sum gradient contributions across
    replicas, and then applies the real optimizer.

    Args:
      grads_and_vars: List of (gradient, variable) pairs as returned by
        compute_gradients().
      global_step: Optional Variable to increment by one after the
        variables have been updated.
      name: Optional name for the returned operation.  Default to the
        name passed to the Optimizer constructor.

    Returns:
      An `Operation` that applies the gradients. If `global_step` was not None,
      that operation also increments `global_step`.

    Raises:
      ValueError: If the grads_and_vars is malformed.
    """
    summed_grads_and_vars = []
    for (grad, var) in grads_and_vars:
      if grad is None:
        summed_grads_and_vars.append((grad, var))
      else:
        summed_grads_and_vars.append((tpu_ops.cross_replica_sum(grad), var))
    return self._opt.apply_gradients(summed_grads_and_vars, global_step, name)
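Example #8 is the heart of CrossShardOptimizer.apply_gradients. A minimal usage sketch, assuming the TF1-style tf.compat.v1 API and placeholder hyperparameters, wraps an ordinary optimizer so that its gradients are summed across replicas before being applied:

import tensorflow.compat.v1 as tf

# Wrap any tf.compat.v1.train optimizer; apply_gradients() then runs the
# cross_replica_sum shown above before delegating to the inner optimizer.
inner = tf.train.AdamOptimizer(learning_rate=1e-3)   # learning rate is a placeholder
optimizer = tf.tpu.CrossShardOptimizer(inner)
# train_op = optimizer.minimize(loss, global_step=tf.train.get_global_step())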
Example #9
    def allreduce(self, x, mesh_axes, reduction_fn_string):
        """Grouped allreduce, (summed across the given dimensions).

    Args:
      x: a LaidOutTensor
      mesh_axes: a list of integers
      reduction_fn_string: "SUM"
    Returns:
      a LaidOutTensor
    Raises:
      ValueError: if the reduction is not yet implemented.
    """
        if not mesh_axes:
            return x
        x = x.to_laid_out_tensor()
        if reduction_fn_string == "SUM":
            group_assignment = self._create_group_assignment(mesh_axes)
            tf_in = x.one_slice
            dtype = tf_in.dtype
            if not (dtype == tf.float32 or dtype == tf.bfloat16):
                tf.logging.info("Casting %s to float32 for allreduce" %
                                tf_in.dtype)
                tf_in = tf.cast(tf_in, tf.float32)
            tf_out = tpu_ops.cross_replica_sum(tf_in, group_assignment)
            if tf_out.dtype != dtype:
                tf_out = tf.cast(tf_out, dtype)
            return self.LaidOutTensor([tf_out])
        else:
            for axis in mesh_axes:
                x = self.allconcat(x, axis, 0, stack=True)
                x = self.LaidOutTensor(
                    [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
            return x
Example #10
    def _reduce_to(self, reduce_op, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        # Validate that the destination is same as the host device
        # Note we don't do this when in replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_device_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(self._host_device)
        else:
            raise ValueError(
                "Multiple devices are not supported for TPUStrategy")

        output = math_ops.add_n(value)
        if reduce_op == reduce_util.ReduceOp.MEAN:
            return output * (1. / len(value))
        return output
Example #11
    def _reduce(self, aggregation, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if aggregation == vs.VariableAggregation.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self.num_towers)
            elif aggregation != vs.VariableAggregation.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        # Validate that the destination is same as the host device
        # Note we don't do this when in replicate context as the reduction is
        # performed on the TPU device itself.
        devices = cross_tower_ops_lib.get_devices_from(destinations)
        if len(devices) == 1:
            assert device_util.canonicalize(
                devices[0]) == device_util.canonicalize(
                    self.get_host_cpu_device(0))
        else:
            raise ValueError(
                'Multiple devices are not supported for TPUStrategy')

        if aggregation == vs.VariableAggregation.ONLY_FIRST_TOWER:
            return value[0]
        output = math_ops.add_n(value)
        if aggregation == vs.VariableAggregation.MEAN:
            return output * (1. / len(value))
        return output
Example #12
  def allreduce(self, x, mesh_axes, reduction_fn_string):
    """Grouped allreduce, (summed across the given dimensions).

    Args:
      x: a LaidOutTensor
      mesh_axes: a list of integers
      reduction_fn_string: "SUM"
    Returns:
      a LaidOutTensor
    Raises:
      ValueError: if the reduction is not yet implemented.
    """
    if not mesh_axes:
      return x
    x = x.to_laid_out_tensor()
    if reduction_fn_string == "SUM":
      partitioning = [
          mtf.pnum_to_group(self.shape, mesh_axes, pnum)
          for pnum in xrange(self.size)]
      return self.LaidOutTensor(
          [tpu_ops.cross_replica_sum(x.one_slice, partitioning)])
    else:
      for axis in mesh_axes:
        x = self.allconcat(x, axis, 0, stack=True)
        x = self.LaidOutTensor(
            [mtf.reduction_fn(reduction_fn_string)(x.one_slice, 0)])
      return x
Example #13
def cross_replica_average(inputs,
                          num_shards=None,
                          num_shards_per_group=None,
                          physical_shape=None,
                          tile_shape=None,
                          use_spatial_partitioning=False):
    """Customized cross replica sum op."""
    # if num_shards_per_group is defined, apply distributed batch norm.
    group_assignment = None

    if num_shards_per_group > 0:
        if num_shards % num_shards_per_group != 0:
            raise ValueError(
                'num_shards: %d mod num_shards_per_group: %d, should be 0' %
                (num_shards, num_shards_per_group))

    num_groups = num_shards // num_shards_per_group

    if physical_shape is not None and tile_shape is not None:
        if use_spatial_partitioning:
            group_assignment = spatial_partitioning_group_assignment(
                physical_shape, tile_shape, num_groups)
        else:
            group_assignment = normal_group_assignment(physical_shape,
                                                       tile_shape, num_groups)
    else:
        group_assignment = [
            [  # pylint: disable=g-complex-comprehension
                x for x in range(num_shards) if x // num_shards_per_group == y
            ] for y in range(num_groups)
        ]

    return tpu_ops.cross_replica_sum(inputs, group_assignment) / math_ops.cast(
        num_shards_per_group, inputs.dtype)
Example #14
def cross_replica_average(inputs, num_shards, distributed_group_size):
    """Calculates the average value of inputs tensor across TPU replicas."""
    group_assignment = None
    if num_shards is not None and distributed_group_size != num_shards:
        # group_assignment must be a 2-D list: one row of replica ids per group.
        group_assignment = [[
            g * distributed_group_size + i
            for i in range(distributed_group_size)
        ] for g in range(num_shards // distributed_group_size)]
    return tpu_ops.cross_replica_sum(inputs, group_assignment) / tf.cast(
        distributed_group_size, inputs.dtype)
Example #15
    def _cross_replica_sum(self, grads_and_vars):
        summed_grads_and_vars = []
        for (grad, var) in grads_and_vars:
            if grad is None:
                summed_grads_and_vars.append((grad, var))
            else:
                with ops.colocate_with(grad):
                    summed_grads_and_vars.append(
                        (tpu_ops.cross_replica_sum(grad, self._group_assignment),
                         var))
        return summed_grads_and_vars
Example #16
def cross_replica_average(inputs, num_shards, distributed_group_size):
  """Calculates the average value of inputs tensor across TPU replicas."""
  group_assignment = None
  if num_shards is not None and distributed_group_size != num_shards:
    group_size = distributed_group_size
    group_assignment = []
    for g in range(num_shards // group_size):
      replica_ids = [g * group_size + i for i in range(group_size)]
      group_assignment.append(replica_ids)

  return tpu_ops.cross_replica_sum(inputs, group_assignment) / tf.cast(
      distributed_group_size, inputs.dtype)
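As a concrete illustration of the grouping in Example #16 (the numbers are assumed, not taken from any of these projects): with num_shards=8 and distributed_group_size=2 the loop builds four groups of two replicas each, so the mean is taken over pairs of cores rather than over all eight.

group_size, num_shards = 2, 8      # assumed example values
group_assignment = [[g * group_size + i for i in range(group_size)]
                    for g in range(num_shards // group_size)]
print(group_assignment)            # [[0, 1], [2, 3], [4, 5], [6, 7]]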
Example #17
    def _cross_replica_average(self, t, num_shards_per_group):
        """Calculates the average value of input tensor across TPU replicas."""
        num_shards = tpu_function.get_tpu_context().number_of_shards
        group_assignment = None
        if num_shards_per_group > 1:
            if num_shards % num_shards_per_group != 0:
                raise ValueError('num_shards: %d mod shards_per_group: %d, should be 0'
                                 % (num_shards, num_shards_per_group))
            num_groups = num_shards // num_shards_per_group
            group_assignment = [[
                x for x in range(num_shards) if x // num_shards_per_group == y
            ] for y in range(num_groups)]
        return tpu_ops.cross_replica_sum(t, group_assignment) / tf.cast(
            num_shards_per_group, t.dtype)
Example #18
def cross_replica_average(inputs, num_shards=None, num_shards_per_group=None):
    """Customized cross replica sum op."""
    # if num_shards_per_group is defined, apply distributed batch norm.
    group_assignment = None
    if num_shards_per_group > 0:
        if num_shards % num_shards_per_group != 0:
            raise ValueError(
                'num_shards: %d mod num_shards_per_group: %d, should be 0' %
                (num_shards, num_shards_per_group))
        num_groups = num_shards // num_shards_per_group
        group_assignment = [[
            x for x in range(num_shards) if x // num_shards_per_group == y
        ] for y in range(num_groups)]

    return tpu_ops.cross_replica_sum(inputs, group_assignment) / math_ops.cast(
        num_shards_per_group, inputs.dtype)
Example #19
File: utils.py  Project: gd-zhang/ACKTR
def cross_replica_mean(tensor, name=None):
    """Takes mean value of a Tensor across all TPU cores.
  Args:
    tensor: Tensor to be synchronized.
    name: None or string. Name of Op.
  Returns:
    Average of Tensor across all TPU cores.
  Raises:
    ValueError: If called outside of TPU context.
  """
    with ops.name_scope(name, "cross_replica_mean", [tensor]):
        num_shards = tpu_function.get_tpu_context().number_of_shards
        if num_shards is None:
            raise ValueError(
                "Cannot take cross_replica_mean() outside of TPU Context.")
        if num_shards == 1:
            return tensor
        return tpu_ops.cross_replica_sum(tensor / num_shards)
Example #20
def cross_replica_mean(tensor, name=None):
  """Takes mean value of a Tensor across all TPU cores.

  Args:
    tensor: Tensor to be synchronized.
    name: None or string. Name of Op.

  Returns:
    Average of Tensor across all TPU cores.

  Raises:
    ValueError: If called outside of TPU context.
  """
  with ops.name_scope(name, "cross_replica_mean", [tensor]):
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
      raise ValueError(
          "Cannot take cross_replica_mean() outside of TPU Context.")
    if num_shards == 1:
      return tensor
    return tpu_ops.cross_replica_sum(tensor / num_shards)
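cross_replica_mean relies on tpu_function.get_tpu_context().number_of_shards, which is only set inside a replicated TPU computation. A hedged usage sketch (placeholder names, tf.compat.v1 API assumed):

import tensorflow.compat.v1 as tf

def tower_fn(x):
    per_core_loss = tf.reduce_mean(tf.square(x))   # differs on every core
    return cross_replica_mean(per_core_loss)       # identical on every core afterwards

# Run the computation replicated so the TPU context is populated, e.g.:
# outputs = tf.tpu.shard(tower_fn, inputs=[sharded_x], num_shards=8)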
Example #21
    def _reduce_to(self, reduce_op, value, destinations):
        if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
            if reduce_op == reduce_util.ReduceOp.MEAN:
                # TODO(jhseu):  Revisit once we support model-parallelism.
                value *= (1. / self._num_replicas_in_sync)
            elif reduce_op != reduce_util.ReduceOp.SUM:
                raise NotImplementedError(
                    "Currently only support sum & mean in TPUStrategy.")
            return tpu_ops.cross_replica_sum(value)

        if not isinstance(value, values.DistributedValues):
            # This function handles reducing values that are not PerReplica or
            # Mirrored values. For example, the same value could be present on all
            # replicas in which case `value` would be a single value or value could
            # be 0.
            return cross_device_ops_lib.reduce_non_distributed_value(
                reduce_op, self._device_map, value, destinations)

        devices = cross_device_ops_lib.get_devices_from(destinations)
        if len(devices) != 1:
            raise ValueError(
                "Multiple devices are not supported for TPUStrategy")

        # Always performs the reduction on the TPU host.
        with ops.device(self._host_device):
            output = math_ops.add_n(value.values)
            if reduce_op == reduce_util.ReduceOp.MEAN:
                output *= (1. / len(value.values))

        # If necessary, copy to requested destination.
        dest_canonical = device_util.canonicalize(devices[0])
        host_canonical = device_util.canonicalize(self._host_device)

        if dest_canonical != host_canonical:
            with ops.device(devices[0]):
                output = array_ops.identity(output)

        return output
Example #22
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    if not isinstance(value, values.DistributedValues):
      # This function handles reducing values that are not PerReplica or
      # Mirrored values. For example, the same value could be present on all
      # replicas in which case `value` would be a single value or value could
      # be 0.
      return cross_device_ops_lib.reduce_non_distributed_value(
          reduce_op, self._device_map, value, destinations)

    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) != 1:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    # Always performs the reduction on the TPU host.
    with ops.device(self._host_device):
      output = math_ops.add_n(value.values)
      if reduce_op == reduce_util.ReduceOp.MEAN:
        output *= (1. / len(value.values))

    # If necessary, copy to requested destination.
    dest_canonical = device_util.canonicalize(devices[0])
    host_canonical = device_util.canonicalize(self._host_device)

    if dest_canonical != host_canonical:
      with ops.device(devices[0]):
        output = array_ops.identity(output)

    return output
Example #23
  def _reduce_to(self, reduce_op, value, destinations):
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      if reduce_op == reduce_util.ReduceOp.MEAN:
        # TODO(jhseu):  Revisit once we support model-parallelism.
        value *= (1. / self._num_replicas_in_sync)
      elif reduce_op != reduce_util.ReduceOp.SUM:
        raise NotImplementedError(
            "Currently only support sum & mean in TPUStrategy.")
      return tpu_ops.cross_replica_sum(value)

    # Validate that the destination is same as the host device
    # Note we don't do this when in replicate context as the reduction is
    # performed on the TPU device itself.
    devices = cross_device_ops_lib.get_devices_from(destinations)
    if len(devices) == 1:
      assert device_util.canonicalize(devices[0]) == device_util.canonicalize(
          self._host_device)
    else:
      raise ValueError("Multiple devices are not supported for TPUStrategy")

    output = math_ops.add_n(value)
    if reduce_op == reduce_util.ReduceOp.MEAN:
      return output * (1. / len(value))
    return output
Example #24
    def get_gradients(self, loss, params):
        num_shards = tpu_function.get_tpu_context().number_of_shards
        grads = super(KerasCrossShardOptimizer,
                      self).get_gradients(loss, params)
        return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]
Example #25
    def _reduce(self, method_string, value, destinations):
        del destinations  # TPU is graph mode only.  Rely on implicit Send/Recv.
        if method_string == 'mean':
            # TODO(jhseu):  Revisit once we support model-parallelism.
            value *= (1. / self._num_cores_per_host)
        return tpu_ops.cross_replica_sum(value)
Example #26
  def _reduce(self, method_string, value, destinations):
    del destinations  # TPU is graph mode only.  Rely on implicit Send/Recv.
    if method_string == 'mean':
      # TODO(jhseu):  Revisit once we support model-parallelism.
      value *= (1. / self._num_cores_per_host)
    return tpu_ops.cross_replica_sum(value)
Example #27
  def get_gradients(self, loss, params):
    num_shards = tpu_function.get_tpu_context().number_of_shards
    grads = super(KerasCrossShardOptimizer, self).get_gradients(loss, params)
    return [tpu_ops.cross_replica_sum(grad) / num_shards for grad in grads]

  def _reduce(self, aggregation, value, destinations):
    del destinations  # TPU is graph mode only.  Rely on implicit Send/Recv.
    if aggregation == vs.VariableAggregation.MEAN:
      # TODO(jhseu):  Revisit once we support model-parallelism.
      value *= (1. / self._num_cores_per_host)
    return tpu_ops.cross_replica_sum(value)
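The last few examples compute a cross-replica mean in two different orders: get_gradients divides the summed gradients by num_shards, while the _reduce variants scale the value by 1/num_cores before calling cross_replica_sum. Both orders are equivalent up to floating-point rounding; a quick NumPy check of that identity:

import numpy as np

values = np.random.rand(8)   # one scalar per replica
n = len(values)
assert np.isclose(values.sum() / n, (values / n).sum())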