def alltoall(self, x, mesh_axis, split_axis, concat_axis):
  """Grouped alltoall (like MPI alltoall with splitting and concatenation).

  Args:
    x: a LaidOutTensor
    mesh_axis: an integer (the mesh axis along which to group)
    split_axis: an integer (the Tensor axis along which to split)
    concat_axis: an integer (the Tensor axis along which to concatenate)
  Returns:
    a LaidOutTensor
  """
  x = x.to_laid_out_tensor()
  t = x.one_slice
  group_assignment = self._create_group_assignment([mesh_axis])
  dtype = t.dtype
  if dtype == tf.float32:
    # There seems to be a bug with float32 alltoall.
    # Do it in bfloat16 until the bug is fixed.
    # TODO(noam): file a bug
    t = tf.to_bfloat16(t)
  t = tpu_ops.all_to_all(
      t,
      concat_dimension=concat_axis,
      split_dimension=split_axis,
      split_count=len(group_assignment[0]),
      group_assignment=group_assignment)
  t = tf.cast(t, dtype)
  x = self.LaidOutTensor([t])
  return x
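# NOTE: the following single-process reference is illustrative only and is not
# part of this class. It is a minimal sketch of the split/concat semantics of a
# grouped alltoall, assuming `tensors` holds one array per participant in the
# group (the name `reference_alltoall` is hypothetical).
import numpy as np  # only needed for this illustrative sketch

def reference_alltoall(tensors, split_axis, concat_axis):
  """Sketch of grouped alltoall semantics on the host (illustrative only)."""
  n = len(tensors)
  # Each participant splits its tensor into n pieces along split_axis.
  pieces = [np.split(t, n, axis=split_axis) for t in tensors]
  # Participant i receives piece i from every participant and concatenates
  # the received pieces along concat_axis.
  return [
      np.concatenate([pieces[j][i] for j in range(n)], axis=concat_axis)
      for i in range(n)
  ]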
def _to_bfloat16_unbiased(x, noise):
  """Convert a float32 to a bfloat16 using randomized roundoff.

  Args:
    x: A float32 Tensor.
    noise: a float32 Tensor with values in [0, 1), broadcastable to
      tf.shape(x)
  Returns:
    A bfloat16 Tensor.
  """
  x_sign = tf.sign(x)
  # Make sure x is positive.  If it is zero, the two candidates are identical.
  x = x * x_sign + 1e-30
  cand1 = tf.to_bfloat16(x)
  cand1_f = tf.to_float(cand1)
  # This relies on the fact that for a positive bfloat16 b,
  # b * 1.005 gives you the next higher bfloat16 and b * 0.995 gives you the
  # next lower one. Both 1.005 and 0.995 are ballpark estimates.
  cand2 = tf.to_bfloat16(
      tf.where(tf.greater(x, cand1_f), cand1_f * 1.005, cand1_f * 0.995))
  ret = _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2)
  return ret * tf.to_bfloat16(x_sign)
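# NOTE: a minimal sketch of what the `_randomized_roundoff_to_bfloat16` helper
# used above might look like, assuming it picks between the two bfloat16
# candidates with probability proportional to x's position between them, so
# the rounding is unbiased in expectation. This is an assumption for
# illustration, not the helper's actual definition; it relies on the module's
# existing TensorFlow 1.x `tf` import.
def _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2):
  """Sketch: round x to cand1 or cand2, unbiased in expectation (assumed)."""
  cand1_f = tf.to_float(cand1)
  cand2_f = tf.to_float(cand2)
  # fpart is x's relative position between the two candidates, in [0, 1].
  fpart = (x - cand1_f) / (cand2_f - cand1_f)
  # Round toward cand2 with probability fpart, toward cand1 otherwise.
  # If the candidates coincide, either branch returns the same value.
  return tf.where(tf.greater(fpart, noise), cand2, cand1)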