Example #1
def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
  """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors
  Raises:
    ValueError: if devices is not a list of n strings
  """
  n = len(xs)
  if len(devices) != n:
    raise ValueError("devices must be a list of length len(xs)")
  if n == 1:
    return xs
  shape = xs[0].shape.as_list()
  # tf.logging.info("allreduce_ring shape = %s" % shape)
  size = None if None in shape else mtf.list_product(shape)
  if size is None or size < 1024 or size % n != 0:
    return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

  def _circular_shift(l, n):
    n %= len(l)
    return l[-n:] + l[:-n]

  def _flatten_and_split(x):
    return tf.split(tf.reshape(x, [size]), n)

  def _concat_and_reshape(xs):
    return tf.reshape(tf.concat(xs, 0), shape)

  # [device, shard]
  x_split = mtf.parallel(devices, _flatten_and_split, xs)
  x_split_t = mtf.transpose_list_of_lists(x_split)

  y_split_t = []
  for shard in xrange(n):
    shard_xs = _circular_shift(x_split_t[shard], shard)
    shard_devices = _circular_shift(devices, shard)
    shard_ys = allreduce_ring_single_shard(shard_xs, shard_devices,
                                           reduction_fn_string)
    y_split_t.append(_circular_shift(shard_ys, -shard))
  y_split = mtf.transpose_list_of_lists(y_split_t)
  ys = mtf.parallel(devices, _concat_and_reshape, y_split)
  return ys
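
A hedged usage sketch follows. These helpers appear to come from the Mesh TensorFlow (mtf) codebase and assume that tf, the mtf helper functions, xrange, and allreduce_ring_single_shard are already in scope; the device names and tensor shape below are illustrative, not taken from the source.

import tensorflow.compat.v1 as tf  # assumption: TF1-style graph mode

devices = ["/gpu:%d" % i for i in range(4)]  # hypothetical device names
xs = []
for d in devices:
  with tf.device(d):
    xs.append(tf.ones([1024, 8]))  # one input tensor per device

# Each returned tensor holds the elementwise sum of all four inputs,
# materialized on its own device.
summed = allreduce_ring(xs, devices, reduction_fn_string="SUM")
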
Example #2
def allconcat_ring(xs, devices, concat_axis):
  """Concatenate all Tensors everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # [target, source]
  parts = [[xs[target] if target == source else None for source in xrange(n)]
           for target in xrange(n)]
  for distance in xrange(1, n // 2 + 1):
    for target in xrange(n):
      source = (target + distance) % n
      if parts[target][source] is None:
        with tf.device(devices[target]):
          parts[target][source] = tf.identity(parts[(target + 1) % n][source])
      source = (target - distance) % n
      if parts[target][source] is None:
        with tf.device(devices[target]):
          parts[target][source] = tf.identity(parts[(target - 1) % n][source])
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
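
A hedged usage sketch (shapes and device names are illustrative, and tf is assumed to be in scope as above): each device contributes one shard, and every device ends up holding the full concatenation along concat_axis.

devices = ["/gpu:%d" % i for i in range(4)]  # hypothetical device names
xs = []
for i, d in enumerate(devices):
  with tf.device(d):
    xs.append(tf.fill([2, 8], float(i)))  # shard i, shape [2, 8]

# Every element of full has shape [8, 8]: the four shards stacked along axis 0.
full = allconcat_ring(xs, devices, concat_axis=0)
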
Example #3
  def slicewise(self, fn, *inputs):
    """Execute a function in parallel on all slices.

    Args:
      fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
      *inputs: a list of inputs.  Each input is either a LaidOutTensor or
        is convertible to a tf.Tensor.
    Returns:
      a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
    """
    if fn == tf.add:
      assert len(inputs) == 2
      if isinstance(inputs[0], mtf.LazyAllreduceSum):
        # sum of LazyAllreduceSum (keep delaying the allreduce)
        return inputs[0] + inputs[1]
    # convert all inputs to LaidOutTensor where possible
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
              else [x] * len(self.devices) for x in inputs]
    ret = mtf.parallel(self.devices, fn, *inputs)
    if isinstance(ret[0], tuple):
      ret = mtf.transpose_list_of_lists(ret)
      return tuple([self.LaidOutTensor(t) for t in ret])
    else:
      return self.LaidOutTensor(ret)
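
A hedged sketch of how this method might be called. mesh_impl, lhs, and rhs are placeholders for a mesh-implementation object that exposes slicewise and for two LaidOutTensors it produced earlier; none of them are defined in the excerpt.

# Elementwise product, computed independently on every slice.
product = mesh_impl.slicewise(tf.multiply, lhs, rhs)

# A tuple-returning fn yields a tuple of LaidOutTensors.
summed, diffed = mesh_impl.slicewise(lambda x, y: (x + y, x - y), lhs, rhs)
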
Example #4
def alltoall_ring(xs, devices, split_axis, concat_axis):
  """MPI alltoall operation.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # set up
  # [target, source]
  parts = [[None] * n for i in xrange(n)]
  def my_split(x, size_splits):
    total_size = tf.shape(x)[split_axis]
    part_size = total_size // sum(size_splits)
    return tf.split(x, [s * part_size for s in size_splits], axis=split_axis)
  forward_message_size = (n - 1) // 2
  backward_message_size = (n - 1) - forward_message_size
  forward_messages = [None] * n
  backward_messages = [None] * n
  for i in xrange(n):
    with tf.device(devices[i]):
      if i >= backward_message_size:
        a, b, c, d = my_split(
            xs[i], [i - backward_message_size,
                    backward_message_size, 1, n - i - 1])
        backward_messages[i] = b
        parts[i][i] = c
        forward_messages[i] = tf.concat([d, a], axis=split_axis)
      else:
        a, b, c, d = my_split(
            xs[i], [i, 1, forward_message_size, backward_message_size - i])
        backward_messages[i] = tf.concat([d, a], axis=split_axis)
        parts[i][i] = b
        forward_messages[i] = c
  for step in xrange(1, max(forward_message_size, backward_message_size) + 1):
    new_forward_messages = [None] * n
    new_backward_messages = [None] * n
    for i in xrange(n):
      with tf.device(devices[i]):
        if forward_message_size > 0:
          parts[i][(i - step) % n], new_forward_messages[i] = my_split(
              forward_messages[(i - 1) % n], [1, forward_message_size - 1])
        if backward_message_size > 0:
          new_backward_messages[i], parts[i][(i + step) % n] = my_split(
              backward_messages[(i + 1) % n], [backward_message_size - 1, 1])
    forward_message_size -= 1
    backward_message_size -= 1
    forward_messages = new_forward_messages
    backward_messages = new_backward_messages
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
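
A hedged usage sketch (device names and shapes are illustrative): with four devices and [4, 8] inputs, device i ends up with the i-th piece of every source tensor, split along split_axis and concatenated along concat_axis.

devices = ["/gpu:%d" % i for i in range(4)]  # hypothetical device names
xs = []
for d in devices:
  with tf.device(d):
    xs.append(tf.ones([4, 8]))

# Each returned tensor has shape [1, 32]: one [1, 8] piece from each of the
# four sources, concatenated along axis 1.
shuffled = alltoall_ring(xs, devices, split_axis=0, concat_axis=1)
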
Example #5
    def assign_to_slices(self, slices):
      """Assign to the slice variables.

      Args:
        slices: a list of tf.Tensor

      Returns:
        a tf.operation
      """
      return tf.group(mtf.parallel(
          self._mesh_impl.devices, tf.assign,
          self.laid_out_tensor.all_slices, slices))
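
A hedged sketch; laid_out_var stands in for whatever object exposes this method (its class is not shown in the excerpt), and new_slices is a list with one tf.Tensor per device slice, each matching that slice's shape.

assign_op = laid_out_var.assign_to_slices(new_slices)
with tf.Session() as sess:
  sess.run(tf.global_variables_initializer())
  sess.run(assign_op)
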
Example #6
def alltoall_pointtwise(xs, devices, split_axis, concat_axis):
  """MPI alltoall operation.

  Implementation of alltoall using pointwise communication.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # [target, source]
  parts = mtf.transpose_list_of_lists(
      mtf.parallel(devices, tf.split, xs, [n] * n, axis=[split_axis] * n))
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
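
This variant shares its call signature with alltoall_ring above, so a hedged swap-in looks identical (reusing the illustrative xs and devices from the alltoall_ring sketch); it trades the ring's staged forwarding for one direct split-and-concat per source.

shuffled = alltoall_pointtwise(xs, devices, split_axis=0, concat_axis=1)
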
Example #7
    def assign_to_slices(self, slice_values, assign_to_tensor_list=None):
      """Assign to the slice variables.

      Args:
        slice_values: a list of tf.Tensor
        assign_to_tensor_list: an optional list of tf.Variable

      Returns:
        a tf.operation
      """
      if assign_to_tensor_list is None:
        assign_to_tensor_list = self._laid_out_tensor.all_slices
      # Handle both N -> 1 and N -> N cases.
      num_slices = min(
          len(assign_to_tensor_list), len(slice_values))
      devices = [""] * num_slices
      return tf.group(
          mtf.parallel(devices, tf.assign, assign_to_tensor_list[:num_slices],
                       slice_values[:num_slices]))
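
A hedged sketch of the optional-target form; laid_out_var, checkpoint_values, and target_slices are placeholders not defined in the excerpt. Because the method clips both lists to the shorter length, passing a single target variable exercises the N -> 1 path.

# N -> N: one value per slice variable.
op_all = laid_out_var.assign_to_slices(checkpoint_values)

# N -> 1: only the first value is consumed when a single target is given.
op_one = laid_out_var.assign_to_slices(
    checkpoint_values, assign_to_tensor_list=target_slices[:1])
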