Example #1
    def alltoall(self, x, mesh_axis, split_axis, concat_axis):
        """Grouped alltoall (like MPI alltoall with splitting and concatenation).

    Args:
      x: a LaidOutTensor
      mesh_axis: an integer the mesh axis along which to group
      split_axis: an integer (the Tensor axis along which to split)
      concat_axis: an integer (the Tensor axis along which to concatenate)
    Returns:
      a LaidOutTensor
    """
        x = x.to_laid_out_tensor()
        t = x.one_slice
        group_assignment = self._create_group_assignment([mesh_axis])
        dtype = t.dtype
        if dtype == tf.float32:
            # There seems to be a bug with float32 alltoall.
            # Do it in bfloat16 until the bug is fixed.
            # TODO(noam): file a bug
            t = tf.to_bfloat16(t)
        t = tpu_ops.all_to_all(t,
                               concat_dimension=concat_axis,
                               split_dimension=split_axis,
                               split_count=len(group_assignment[0]),
                               group_assignment=group_assignment)
        t = tf.cast(t, dtype)
        x = self.LaidOutTensor([t])
        return x
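For intuition, the split/concat semantics of a grouped alltoall can be checked outside of TPU code. The NumPy sketch below (function name and setup are hypothetical, not part of the library) mimics what the n cores in one group collectively do: core j splits its tensor into n blocks along split_axis, block i goes to core i, and every core concatenates the blocks it receives along concat_axis.

import numpy as np

def alltoall_reference(inputs, split_axis, concat_axis):
    """Reference semantics of a grouped alltoall over n participants."""
    n = len(inputs)
    # Each participant splits its tensor into n equal blocks along split_axis.
    blocks = [np.split(x, n, axis=split_axis) for x in inputs]
    # Participant i receives block i from every participant and concatenates
    # the received blocks along concat_axis.
    return [np.concatenate([blocks[j][i] for j in range(n)], axis=concat_axis)
            for i in range(n)]

# Four participants, each holding a [4, 8] array.
xs = [np.full((4, 8), i, dtype=np.float32) for i in range(4)]
ys = alltoall_reference(xs, split_axis=0, concat_axis=1)
print(ys[0].shape)  # (1, 32): split axis shrinks by n, concat axis grows by n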
Example #2
import tensorflow as tf  # TF 1.x API (provides tf.to_bfloat16 / tf.to_float)

def _to_bfloat16_unbiased(x, noise):
    """Convert a float32 to a bfloat16 using randomized roundoff.

  Args:
    x: A float32 Tensor.
    noise: a float32 Tensor with values in [0, 1), broadcastable to tf.shape(x)
  Returns:
    A float32 Tensor.
  """
    x_sign = tf.sign(x)
    # Make sure x is positive.  If it is zero, the two candidates are identical.
    x = x * x_sign + 1e-30
    cand1 = tf.to_bfloat16(x)
    cand1_f = tf.to_float(cand1)
    # This relies on the fact that for a positive bfloat16 b, b * 1.005 gives
    # you the next higher bfloat16 and b * 0.995 gives you the next lower one.
    # Both 1.005 and 0.995 are ballpark estimates.
    cand2 = tf.to_bfloat16(
        tf.where(tf.greater(x, cand1_f), cand1_f * 1.005, cand1_f * 0.995))
    ret = _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2)
    return ret * tf.to_bfloat16(x_sign)
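The helper _randomized_roundoff_to_bfloat16 is not shown in this example. As a minimal sketch of the unbiased-rounding step it performs (the library's actual implementation may differ): given two bfloat16 candidates bracketing x, choosing cand2 with probability (x - cand1) / (cand2 - cand1) makes the expected value of the result exactly x, which is what makes the rounding unbiased.

def _randomized_roundoff_to_bfloat16(x, noise, cand1, cand2):
    """Sketch: round x to cand1 or cand2 so the result is unbiased.

    cand1 and cand2 are bfloat16 values bracketing x (in either order);
    noise is uniform in [0, 1), broadcastable to tf.shape(x).
    """
    cand1_f = tf.to_float(cand1)
    cand2_f = tf.to_float(cand2)
    # Fractional position of x between the two candidates; this is also the
    # probability with which we should round to cand2.
    fpart = (x - cand1_f) / (cand2_f - cand1_f)
    # For noise uniform in [0, 1), P(fpart > noise) == fpart, so cand2 is
    # picked with exactly the probability that keeps E[result] == x.
    return tf.where(tf.greater(fpart, noise), cand2, cand1)

Note that the expectation works out regardless of which candidate is larger: E[result] = cand1 + fpart * (cand2 - cand1) = x.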