Example #1
def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
    """Exchange data across TPU replicas.

    Args:
      x: The local tensor.
      concat_dimension: The dimension number to concatenate.
      split_dimension: The dimension number to split.
      split_count: The number of splits; this number must equal the subgroup
        size (group_assignment.get_shape()[1]).
      group_assignment: Optional 2D int32 list with shape [num_groups,
        num_replicas_per_group]. `group_assignment[i]` represents the replica
        ids in the ith subgroup.
      name: Optional op name.

    Returns:
      A `Tensor` concatenated from data on different replicas.
    """
    if group_assignment is None:
        group_assignment = _create_default_group_assignment()
    return gen_tpu_ops.all_to_all(x,
                                  group_assignment,
                                  concat_dimension=concat_dimension,
                                  split_dimension=split_dimension,
                                  split_count=split_count,
                                  name=name)
Example #2
def all_to_all(x,
               concat_dimension,
               split_dimension,
               split_count,
               group_assignment=None,
               name=None):
  """Exchange data across TPU replicas.

  Args:
    x: The local tensor.
    concat_dimension: The dimension number to concatenate.
    split_dimension: The dimension number to split.
    split_count: The number of splits; this number must equal the subgroup
      size (group_assignment.get_shape()[1]).
    group_assignment: Optional 2D int32 list with shape [num_groups,
      num_replicas_per_group]. `group_assignment[i]` represents the replica
      ids in the ith subgroup.
    name: Optional op name.

  Returns:
    A `Tensor` concatenated from data on different replicas.
  """
  if group_assignment is None:
    group_assignment = _create_default_group_assignment()
  return gen_tpu_ops.all_to_all(
      x,
      group_assignment,
      concat_dimension=concat_dimension,
      split_dimension=split_dimension,
      split_count=split_count,
      name=name)
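When group_assignment is omitted, the private helper _create_default_group_assignment supplies one. A minimal sketch of what such a default plausibly looks like, one group spanning all replicas; the function below is a hypothetical stand-in, not TensorFlow's actual helper:

import numpy as np

def default_group_assignment(num_replicas):
    # Hypothetical stand-in: a single group containing every replica,
    # matching the documented shape [num_groups, num_replicas_per_group].
    return np.arange(num_replicas, dtype=np.int32).reshape(1, num_replicas)

print(default_group_assignment(4))  # [[0 1 2 3]]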
Example #3
def _all_to_all_grad(op, grad):
    # The gradient of an all-to-all is also an all-to-all, but with
    # split_dimension and concat_dimension swapped.
    # The gradient with respect to group_assignment is None.
    return [
        gen_tpu_ops.all_to_all(grad,
                               op.inputs[1],
                               concat_dimension=op.get_attr("split_dimension"),
                               split_dimension=op.get_attr("concat_dimension"),
                               split_count=op.get_attr("split_count")), None
    ]
Example #4
def _all_to_all_grad(op, grad):
  # The gradient of an all-to-all is also an all-to-all, but with
  # split_dimension and concat_dimension swapped.
  # The gradient with respect to group_assignment is None.
  return [
      gen_tpu_ops.all_to_all(
          grad,
          op.inputs[1],
          concat_dimension=op.get_attr("split_dimension"),
          split_dimension=op.get_attr("concat_dimension"),
          split_count=op.get_attr("split_count")), None
  ]
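The dimension swap in the gradient follows because an all-to-all only permutes elements: its Jacobian is a permutation matrix, so the transpose that backpropagation needs equals the inverse, and the inverse is the same exchange with split_dimension and concat_dimension swapped. Reusing the simulate_all_to_all sketch and the xs arrays from the snippet after Example #1 (made-up helpers, not TensorFlow code), the round trip can be checked directly:

# Forward pass: split on axis 0, concatenate on axis 1.
ys = simulate_all_to_all(xs, concat_dimension=1, split_dimension=0, split_count=2)
# "Gradient" exchange: the same op with the two dimensions swapped.
back = simulate_all_to_all(ys, concat_dimension=0, split_dimension=1, split_count=2)
# The swapped exchange undoes the forward one exactly.
assert all(np.array_equal(a, b) for a, b in zip(xs, back))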