def cross_replica_sum(x, group_assignment=None, name=None):
    """Sum the input tensor across replicas according to group_assignment.

    Args:
      x: The local tensor to sum.
      group_assignment: Optional 2d int32 list with shape
        [num_groups, num_replicas_per_group]. `group_assignment[i]`
        represents the replica ids in the ith subgroup. If None, the value
        returned by `_create_default_group_assignment()` is used instead.
      name: Optional op name.

    Returns:
      A `Tensor` which is summed across replicas.
    """
    # NOTE(review): the default assignment comes from a helper defined
    # elsewhere; presumably it groups all replicas together -- confirm there.
    if group_assignment is None:
        group_assignment = _create_default_group_assignment()
    return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
def _cross_replica_sum_grad(op, grad):
    """Compute gradients for the cross-replica sum op.

    Args:
      op: The forward cross-replica sum op. `op.inputs[1]` is the
        group_assignment input it was built with.
      grad: The incoming gradient with respect to the op's output.

    Returns:
      A two-element list: the gradient with respect to the summed input,
      and None for group_assignment.
    """
    # NOTE(review): no gradient-registration decorator is visible on this
    # function; confirm it is registered for the forward op elsewhere.
    # The gradient of a cross-replica sum is itself a cross-replica sum over
    # the same replica groups, so reuse the forward op's group assignment.
    # group_assignment is not differentiable, so its gradient is None.
    return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]