Example 1
 def test_all_to_all_group_assignment_wrong_shape(self):
     with self.assertRaisesRegex(ValueError,
                                 "group_assignment must have rank 2"):
         tpu_ops.all_to_all(x=[0.0, 0.1652, 0.6543],
                            group_assignment=[1, -1],
                            concat_dimension=0,
                            split_dimension=0,
                            split_count=2)
Example 2
 def test_all_to_all_zero_split_count(self):
     with self.assertRaisesRegex(ValueError,
                                 "split_count 0 must at least be one"):
         tpu_ops.all_to_all(x=[0.0, 0.1652, 0.6543],
                            group_assignment=[1, -1],
                            concat_dimension=0,
                            split_dimension=0,
                            split_count=0)
Example 3
 def test_all_to_all_split_count_not_divide_input_shape(self):
     with self.assertRaisesRegex(
             ValueError,
             "input dimension 3 not divisible by split_count 2"):
         tpu_ops.all_to_all(x=[[0.0], [0.1652], [0.6543]],
                            group_assignment=[[0, 1], [2, 3]],
                            concat_dimension=1,
                            split_dimension=0,
                            split_count=2)
Example 4
 def test_all_to_all_split_count_not_equal_to_group_assignment_shape(self):
     with self.assertRaisesRegex(
             ValueError,
             "split_count 1 must equal the size of the second dimension "
             "of group_assignment 2"):
         tpu_ops.all_to_all(x=[0.0, 0.1652, 0.6543],
                            group_assignment=[[0, 1], [2, 3]],
                            concat_dimension=0,
                            split_dimension=0,
                            split_count=1)
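The four tests above exercise all_to_all's static argument checks: group_assignment must have rank 2, split_count must be at least one, it must divide the size of split_dimension, and it must equal the second dimension of group_assignment. As a minimal sketch (not taken from the examples, and assuming the op is built inside a TPU computation), a call that passes all four checks could look like this; the import path is the one used by TensorFlow's own TPU tests:

 # Sketch only: assumes a TPU computation context; import path as in TF's TPU tests.
 from tensorflow.python.tpu.ops import tpu_ops

 y = tpu_ops.all_to_all(
     x=[[0.0, 1.0], [2.0, 3.0]],   # split dimension 0 has size 2
     group_assignment=[[0, 1]],    # rank 2; second dimension has size 2
     concat_dimension=1,
     split_dimension=0,
     split_count=2)                # at least one, divides 2, equals group size 2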
Example 5
    def alltoall(self, x, mesh_axis, split_axis, concat_axis):
        """Grouped alltoall (like MPI alltoall with splitting and concatenation).

        Args:
          x: a LaidOutTensor
          mesh_axis: an integer (the mesh axis along which to group)
          split_axis: an integer (the Tensor axis along which to split)
          concat_axis: an integer (the Tensor axis along which to concatenate)

        Returns:
          a LaidOutTensor
        """
        x = x.to_laid_out_tensor()
        t = x.one_slice
        group_assignment = self._create_group_assignment([mesh_axis])
        dtype = t.dtype
        if dtype == tf.float32:
            # There seems to be a bug with float32 alltoall.
            # Do it in bfloat16 until the bug is fixed.
            # TODO(noam): file a bug
            t = tf.to_bfloat16(t)
        t = tpu_ops.all_to_all(t,
                               concat_dimension=concat_axis,
                               split_dimension=split_axis,
                               split_count=len(group_assignment[0]),
                               group_assignment=group_assignment)
        t = tf.cast(t, dtype)
        x = self.LaidOutTensor([t])
        return x
Example 6
  def _broadcast_to(self, tensor, destinations):
    del destinations
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcasted value from the first replica because the only
      # caller of this is for ONLY_FIRST_REPLICA variables aggregation.
      return result[0]
    return tensor
Example 7
  def _broadcast_to(self, tensor, destinations):
    del destinations
    # This is both a fast path for Python constants, and a way to delay
    # converting Python values to a tensor until we know what type it
    # should be converted to. Otherwise we have trouble with:
    #   global_step.assign_add(1)
    # since the `1` gets broadcast as an int32 but global_step is int64.
    if isinstance(tensor, (float, int)):
      return tensor
    if values._enclosing_tpu_context() is not None:  # pylint: disable=protected-access
      broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]
      result = tpu_ops.all_to_all(
          broadcast_tensor,
          concat_dimension=0,
          split_dimension=0,
          split_count=self._num_replicas_in_sync)

      # This uses the broadcasted value from the first replica because the only
      # caller of this is for ONLY_FIRST_REPLICA variables aggregation.
      return result[0]
    return tensor
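Examples 6 and 7 use all_to_all as a broadcast: every replica stacks self._num_replicas_in_sync copies of its own tensor, the exchange leaves each replica holding one block from every peer in replica order, and result[0] is therefore replica 0's value on every replica. The NumPy sketch below (an illustration of the assumed exchange semantics, not TPU code) reproduces that pattern for two replicas:

 # NumPy illustration only: assumes all_to_all sends block i of each replica's
 # input to replica i and concatenates the received blocks in replica order.
 import numpy as np

 def simulate_all_to_all(per_replica_inputs, split_count):
     # split_dimension = concat_dimension = 0, one group containing all replicas.
     blocks = [np.split(x, split_count, axis=0) for x in per_replica_inputs]
     return [np.concatenate([blocks[sender][receiver]
                             for sender in range(split_count)], axis=0)
             for receiver in range(split_count)]

 num_replicas = 2
 per_replica_value = [np.array([[10.0]]), np.array([[20.0]])]  # replica 0, replica 1
 # Mirrors `broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)]`.
 stacked = [np.concatenate([v] * num_replicas, axis=0) for v in per_replica_value]
 exchanged = simulate_all_to_all(stacked, split_count=num_replicas)
 print([r[0] for r in exchanged])  # replica 0's value everywhere: [array([10.]), array([10.])]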