Example no. 1
0
    def all_to_all(x,
                   concat_dimension,
                   split_dimension,
                   split_count,
                   group_assignment=None,
                   name=None):
        """Exchange data across TPU replicas.

    Args:
      x: The local tensor.
      concat_dimension: The dimension number to concatenate.
      split_dimension: The dimension number to split.
      split_count: The number of splits; must equal the sub-group size
        (group_assignment.get_shape()[1]).
      group_assignment: Optional 2d int32 lists with shape [num_groups,
        num_replicas_per_group]. `group_assignment[i]` represents the replica
        ids in the ith subgroup.
      name: Optional op name.

    Returns:
      A `Tensor` which is concatenated by data from different replicas.
    """
        # When the caller gives no explicit grouping, fall back to the
        # module-level default group assignment.
        if group_assignment is None:
            group_assignment = _create_default_group_assignment()
        return gen_tpu_ops.all_to_all(x,
                                      group_assignment,
                                      concat_dimension=concat_dimension,
                                      split_dimension=split_dimension,
                                      split_count=split_count,
                                      name=name)
Example no. 2
0
  def all_to_all(x,
                 concat_dimension,
                 split_dimension,
                 split_count,
                 group_assignment=None,
                 name=None):
    """Exchange data across TPU replicas.

    Args:
      x: The local tensor to exchange.
      concat_dimension: Index of the dimension along which results are
        concatenated.
      split_dimension: Index of the dimension along which `x` is split.
      split_count: Number of splits; it must match the sub-group size
        (`group_assignment.get_shape()[1]`).
      group_assignment: Optional 2d int32 lists of shape [num_groups,
        num_replicas_per_group]; `group_assignment[i]` lists the replica ids
        in subgroup i. If None, a default group assignment is created.
      name: Optional op name.

    Returns:
      A `Tensor` concatenated from the data of the participating replicas.
    """
    # Resolve the grouping once, then hand everything to the generated op.
    groups = (_create_default_group_assignment()
              if group_assignment is None else group_assignment)
    return gen_tpu_ops.all_to_all(
        x,
        groups,
        concat_dimension=concat_dimension,
        split_dimension=split_dimension,
        split_count=split_count,
        name=name)
Example no. 3
0
 def _all_to_all_grad(op, grad):
     # The gradient of an all-to-all is also an all-to-all, but with the
     # split_dimension and concat_dimension swapped.
     # The gradient with respect to group_assignment is None.
     return [
         gen_tpu_ops.all_to_all(
             grad,
             op.inputs[1],
             concat_dimension=op.get_attr("split_dimension"),
             split_dimension=op.get_attr("concat_dimension"),
             split_count=op.get_attr("split_count")), None
     ]
Example no. 4
0
 def _all_to_all_grad(op, grad):
   # The backward pass of an all-to-all is itself an all-to-all, with the
   # roles of concat_dimension and split_dimension exchanged.
   dx = gen_tpu_ops.all_to_all(
       grad,
       op.inputs[1],
       concat_dimension=op.get_attr("split_dimension"),
       split_dimension=op.get_attr("concat_dimension"),
       split_count=op.get_attr("split_count"))
   # No gradient flows into the group_assignment input.
   return [dx, None]