Example #1
def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 lists with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()

  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
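For intuition, here is a small, hypothetical pure-Python model of the reduction described by the docstring above (it is not the TPU op itself): every replica listed in a `group_assignment` subgroup ends up with the sum of that subgroup's local values.

# Illustration only: a plain-Python stand-in for the reduction semantics,
# not the actual gen_tpu_ops.cross_replica_sum kernel.
def simulate_cross_replica_sum(per_replica_values, group_assignment):
    result = list(per_replica_values)
    for group in group_assignment:
        group_sum = sum(per_replica_values[r] for r in group)
        for r in group:
            result[r] = group_sum  # each replica in the group gets the group total
    return result

# Four replicas summed in pairs: replicas 0/1 share one total, 2/3 another.
print(simulate_cross_replica_sum([1.0, 2.0, 3.0, 4.0], [[0, 1], [2, 3]]))
# [3.0, 3.0, 7.0, 7.0]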
Example #2
def cross_replica_sum(x, group_assignment=None, name=None):
    """Sum the input tensor across replicas according to group_assignment.

    Args:
      x: The local tensor to sum.
      group_assignment: Optional 2d int32 lists with shape [num_groups,
        num_replicas_per_group]. `group_assignment[i]` represents the replica
        ids in the ith subgroup.
      name: Optional op name.

    Returns:
      A `Tensor` which is summed across replicas.
    """
    if group_assignment is None:
        group_assignment = _create_default_group_assignment()

    return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
Example #3
def cross_replica_sum(x, group_assignment=None, name=None):
    """Sum the input tensor across replicas according to group_assignment.

    Args:
      x: The local tensor to sum.
      group_assignment: Optional 2d int32 lists with shape [num_groups,
        num_replicas_per_group]. `group_assignment[i]` represents the replica
        ids in the ith subgroup.
      name: Optional op name.

    Returns:
      A `Tensor` which is summed across replicas.
    """
    if group_assignment is None:
        num_shards = tpu_function.get_tpu_context().number_of_shards
        if num_shards is None:
            logging.warning(
                "cross_replica_sum should be used within a tpu_shard_context, but "
                "got unset number_of_shards. Assuming 1.")
            num_shards = 1
        group_assignment = [list(range(num_shards))]

    return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
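Examples #1–#3 differ only in where the default group assignment comes from: #1 and #2 call a `_create_default_group_assignment()` helper, while #3 inlines the same logic. A sketch of what such a helper presumably looks like, reconstructed from Example #3 (the `tpu_function` and `logging` modules are the ones already referenced above):

def _create_default_group_assignment():
    # Reconstruction based on Example #3: one group containing every replica,
    # i.e. a global sum across all shards of the current TPU context.
    num_shards = tpu_function.get_tpu_context().number_of_shards
    if num_shards is None:
        logging.warning(
            "cross_replica_sum should be used within a tpu_shard_context, but "
            "got unset number_of_shards. Assuming 1.")
        num_shards = 1
    return [list(range(num_shards))]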
Example #4
def _cross_replica_sum_grad(op, grad):
    # The gradient of a cross replica sum is also a cross-replica sum.
    # The gradient with respect to group_assignment is None.
    return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
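In TensorFlow, gradient functions like those in Examples #4–#9 are attached to an op type with `ops.RegisterGradient`. Below is a minimal sketch of how the Example #4 variant would typically be registered; the import paths and the "CrossReplicaSum" op name are assumptions based on TensorFlow conventions, not taken from this listing.

# Sketch (assumed import paths and op name):
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_tpu_ops


@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
    # Sum the upstream gradient across the same replica groups; the
    # group_assignment input (op.inputs[1]) receives no gradient.
    return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]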
Example #5
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
Example #6
def _cross_replica_sum_grad(op, grad):
    # The gradient of a cross replica sum is also a cross-replica sum.
    return gen_tpu_ops.cross_replica_sum(grad,
                                         op.get_attr("group_assignment"))
Example #7
def _cross_replica_sum_grad(op, grad):
  del op  # Unused
  # The gradient of a cross replica sum is also a cross-replica sum.
  return gen_tpu_ops.cross_replica_sum(grad)
Example #8
def _cross_replica_sum_grad(op, grad):
    del op  # Unused
    # The gradient of a cross replica sum is also a cross-replica sum.
    return gen_tpu_ops.cross_replica_sum(grad)
Example #9
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross replica sum is also a cross-replica sum.
  return gen_tpu_ops.cross_replica_sum(grad, op.get_attr("group_assignment"))
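The gradient variants above differ mainly in how the replica grouping reaches the backward pass: Examples #4 and #5 read it from the op's second input (`op.inputs[1]`), Examples #6 and #9 read it from a `group_assignment` attribute, and Examples #7 and #8 discard `op` entirely and rely on the default grouping, presumably reflecting different revisions of the op's interface. In every case the underlying identity is the same: the gradient of a sum across replicas is that same sum applied to the incoming gradient.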