def test_all_to_all_group_assignment_wrong_shape(self):
    """all_to_all must reject a group_assignment that is not rank 2."""
    inputs = [0.0, 0.1652, 0.6543]
    flat_assignment = [1, -1]  # rank 1 — invalid on purpose
    with self.assertRaisesRegex(ValueError,
                                "group_assignment must have rank 2"):
        tpu_ops.all_to_all(
            x=inputs,
            group_assignment=flat_assignment,
            concat_dimension=0,
            split_dimension=0,
            split_count=2)
def test_all_to_all_zero_split_count(self):
    """all_to_all must reject a split_count of zero."""
    inputs = [0.0, 0.1652, 0.6543]
    with self.assertRaisesRegex(ValueError,
                                "split_count 0 must at least be one"):
        tpu_ops.all_to_all(
            x=inputs,
            group_assignment=[1, -1],
            concat_dimension=0,
            split_dimension=0,
            split_count=0)
def test_all_to_all_split_count_not_divide_input_shape(self):
    """all_to_all must reject a split dimension not divisible by split_count."""
    # 3 rows along the split dimension, split_count 2 — does not divide.
    inputs = [[0.0], [0.1652], [0.6543]]
    with self.assertRaisesRegex(
        ValueError, "input dimension 3 not divisible by split_count 2"):
        tpu_ops.all_to_all(
            x=inputs,
            group_assignment=[[0, 1], [2, 3]],
            concat_dimension=1,
            split_dimension=0,
            split_count=2)
def test_all_to_all_split_count_not_equal_to_group_assignment_shape(self):
    """all_to_all must require split_count == group_assignment.shape[1]."""
    # group_assignment rows have size 2, but split_count is 1 — mismatch.
    with self.assertRaisesRegex(
        ValueError,
        "split_count 1 must equal the size of the second dimension "
        "of group_assignment 2"):
        tpu_ops.all_to_all(
            x=[0.0, 0.1652, 0.6543],
            group_assignment=[[0, 1], [2, 3]],
            concat_dimension=0,
            split_dimension=0,
            split_count=1)
def alltoall(self, x, mesh_axis, split_axis, concat_axis):
    """Grouped alltoall (like MPI alltoall with splitting and concatenation).

    Args:
      x: a LaidOutTensor
      mesh_axis: an integer the mesh axis along which to group
      split_axis: an integer (the Tensor axis along which to split)
      concat_axis: an integer (the Tensor axis along which to concatenate)
    Returns:
      a LaidOutTensor
    """
    laid_out = x.to_laid_out_tensor()
    slice_t = laid_out.one_slice
    groups = self._create_group_assignment([mesh_axis])
    original_dtype = slice_t.dtype
    if original_dtype == tf.float32:
        # There seems to be a bug with float32 alltoall.
        # Do it in bfloat16 until the bug is fixed.
        # TODO(noam): file a bug
        slice_t = tf.to_bfloat16(slice_t)
    shuffled = tpu_ops.all_to_all(
        slice_t,
        concat_dimension=concat_axis,
        split_dimension=split_axis,
        split_count=len(groups[0]),
        group_assignment=groups)
    # Restore the caller-visible dtype after the bfloat16 workaround.
    shuffled = tf.cast(shuffled, original_dtype)
    return self.LaidOutTensor([shuffled])
def _broadcast_to(self, tensor, destinations): del destinations if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] result = tpu_ops.all_to_all( broadcast_tensor, concat_dimension=0, split_dimension=0, split_count=self._num_replicas_in_sync) # This uses the broadcasted value from the first replica because the only # caller of this is for ONLY_FIRST_REPLICA variables aggregation. return result[0] return tensor
def _broadcast_to(self, tensor, destinations): del destinations # This is both a fast path for Python constants, and a way to delay # converting Python values to a tensor until we know what type it # should be converted to. Otherwise we have trouble with: # global_step.assign_add(1) # since the `1` gets broadcast as an int32 but global_step is int64. if isinstance(tensor, (float, int)): return tensor if values._enclosing_tpu_context() is not None: # pylint: disable=protected-access broadcast_tensor = [tensor for _ in range(self._num_replicas_in_sync)] result = tpu_ops.all_to_all( broadcast_tensor, concat_dimension=0, split_dimension=0, split_count=self._num_replicas_in_sync) # This uses the broadcasted value from the first replica because the only # caller of this is for ONLY_FIRST_REPLICA variables aggregation. return result[0] return tensor