# Imports follow TensorFlow's internal module layout for the TPU ops file.
from tensorflow.python.framework import ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import tpu_function
from tensorflow.python.tpu.ops import gen_tpu_ops


def cross_replica_sum(x, group_assignment=None, name=None):
  """Sum the input tensor across replicas according to group_assignment.

  Args:
    x: The local tensor to sum.
    group_assignment: Optional 2d int32 lists with shape
      [num_groups, num_replicas_per_group]. `group_assignment[i]` represents
      the replica ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` which is summed across replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.cross_replica_sum(x, group_assignment, name=name)
def _create_default_group_assignment():
  """Returns the default group assignment: a single group of all replicas."""
  num_shards = tpu_function.get_tpu_context().number_of_shards
  if num_shards is None:
    logging.warning(
        "cross_replica_sum should be used within a tpu_shard_context, but "
        "got unset number_of_shards. Assuming 1.")
    num_shards = 1
  return [list(range(num_shards))]
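# Usage sketch (assumptions, not from the source): a TF1-style program driving
# two replicas through `tpu.replicate`. `group_assignment=[[0, 1]]` puts both
# replicas in a single group, so each replica receives the same group-wide sum.
# The helper name `_example_cross_replica_mean` is purely illustrative.
def _example_cross_replica_mean():
  import tensorflow.compat.v1 as tf
  from tensorflow.python.tpu import tpu

  def computation(x):
    total = cross_replica_sum(x, group_assignment=[[0, 1]])
    return total / 2.0  # divide by the group size for a cross-replica mean

  # One inner argument list per replica: replica 0 gets 1.0, replica 1 gets 3.0.
  return tpu.replicate(computation,
                       inputs=[[tf.constant(1.0)], [tf.constant(3.0)]])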
@ops.RegisterGradient("CrossReplicaSum")
def _cross_replica_sum_grad(op, grad):
  # The gradient of a cross-replica sum is also a cross-replica sum.
  # The gradient with respect to group_assignment is None.
  return [gen_tpu_ops.cross_replica_sum(grad, op.inputs[1]), None]
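# Why the backward pass reuses the same op (a sketch, not from the source):
# within a group g, every replica r computes y_r = sum(x_s for s in g), so
# dL/dx_s = sum(dL/dy_r * dy_r/dx_s for r in g) = sum(dL/dy_r for r in g),
# which is the incoming gradients summed over the same group, i.e. another
# cross-replica sum with the original group_assignment (op.inputs[1]).
#
# Hypothetical check, reusing the two-replica assumptions sketched above:
#   loss = tf.reduce_sum(cross_replica_sum(x, group_assignment=[[0, 1]]))
#   (dx,) = tf.gradients(loss, [x])  # dx is itself a CrossReplicaSum of ones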