def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
    """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors
  Raises:
    ValueError: if devices is not a list of n strings
  """
    n = len(xs)
    if len(devices) != n:
        raise ValueError("devices must be a list of length len(xs)")
    if n == 1:
        return xs
    shape = xs[0].shape.as_list()
    # tf.logging.info("allreduce_ring shape = %s" % shape)
    size = None if None in shape else mtf.list_product(shape)
    if size is None or size < 1024 or size % n != 0:
        return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

    def _circular_shift(l, n):
        n %= len(l)
        return l[-n:] + l[:-n]

    def _flatten_and_split(x):
        # tf.reshape treats [-1] as a special value denoting 1D flattening.
        return tf.split(tf.reshape(x, [-1]), n)

    def _concat_and_reshape(xs):
        return tf.reshape(tf.concat(xs, 0), shape)

    # [device, shard]
    x_split = mtf.parallel(devices, _flatten_and_split, xs)
    x_split_t = mtf.transpose_list_of_lists(x_split)

    y_split_t = []
    for shard in xrange(n):
        shard_xs = _circular_shift(x_split_t[shard], shard)
        shard_devices = _circular_shift(devices, shard)
        shard_ys = allreduce_ring_single_shard(shard_xs, shard_devices,
                                               reduction_fn_string)
        y_split_t.append(_circular_shift(shard_ys, -shard))
    y_split = mtf.transpose_list_of_lists(y_split_t)
    ys = mtf.parallel(devices, _concat_and_reshape, y_split)
    return ys
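
A minimal usage sketch for allreduce_ring, assuming a TF1-compatible `tf` import and illustrative CPU device strings; with tensors this small the function falls back to allreduce_ring_single_shard, which is defined alongside these snippets in the library.

import tensorflow.compat.v1 as tf  # assumption: TF1-compatible API, as the snippets above use

devices = ["/cpu:0", "/cpu:0"]  # ring of (here identical) devices
xs = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
summed = allreduce_ring(xs, devices, "SUM")  # each output should hold [4.0, 6.0]
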
Example #2
    def assign_to_slices(self,
                         assign_fn,
                         values,
                         assign_to_tensor_list=None):
        """Assign to the slice variables.

        Args:
          assign_fn: a function from
            (mtf.Variable, tf.Variable, tf.Tensor) -> tf.Operation
          values: a list of tf.Tensor
          assign_to_tensor_list: an optional list of tf.Variable

        Returns:
          a tf.Operation
        """
        if assign_to_tensor_list is None:
            assign_to_tensor_list = self._laid_out_tensor.all_slices
        # Handle both N -> 1 and N -> N cases.
        num_slices = min(len(assign_to_tensor_list), len(values))
        devices = [""] * num_slices
        return tf.group(
            mtf.parallel(devices, assign_fn,
                         [self._variable] * len(devices),
                         assign_to_tensor_list[:num_slices],
                         values[:num_slices]))
def allconcat_ring(xs, devices, concat_axis):
    """Concatenate all Tensors everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
    n = len(xs)
    if n == 1:
        return xs
    # [target, source]
    parts = [[
        xs[target] if target == source else None for source in xrange(n)
    ] for target in xrange(n)]
    for distance in xrange(1, n // 2 + 1):
        for target in xrange(n):
            source = (target + distance) % n
            if parts[target][source] is None:
                with tf.device(devices[target]):
                    parts[target][source] = tf.identity(parts[(target + 1) %
                                                              n][source])
            source = (target - distance) % n
            if parts[target][source] is None:
                with tf.device(devices[target]):
                    parts[target][source] = tf.identity(parts[(target - 1) %
                                                              n][source])
    return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
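
A minimal usage sketch for allconcat_ring, under the same assumptions as above (TF1-style `tf`, illustrative device strings):

devices = ["/cpu:0", "/cpu:0", "/cpu:0"]
xs = [tf.fill([2, 1], float(i)) for i in range(3)]  # one [2, 1] slice per device
gathered = allconcat_ring(xs, devices, concat_axis=1)
# every element of `gathered` is the full [2, 3] concatenation of all three slices
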
    def slicewise(self, fn, *inputs):
        """Execute a function in parallel on all slices.

    Args:
      fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
      *inputs: a list of inputs.  Each input is either a LaidOutTensor or
        is convertible to a tf.Tensor.
    Returns:
      a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
    """
        if fn == tf.add:
            assert len(inputs) == 2
            if isinstance(inputs[0], mtf.LazyAllreduceSum):
                # sum of LazyAllreduceSum (keep delaying the allreduce)
                return inputs[0] + inputs[1]
        # convert all inputs to LaidOutTensor where possible
        inputs = mtf.convert_args_to_laid_out_tensors(inputs)
        inputs = [
            x.tensor_list if isinstance(x, self.LaidOutTensor) else [x] *
            len(self.devices) for x in inputs
        ]
        ret = mtf.parallel(self.devices, fn, *inputs)
        if isinstance(ret[0], tuple):
            ret = mtf.transpose_list_of_lists(ret)
            return tuple([self.LaidOutTensor(t) for t in ret])
        else:
            return self.LaidOutTensor(ret)
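
The key step in slicewise is the broadcast above: any input that is not a LaidOutTensor is replicated once per device before mtf.parallel maps fn over the per-device slices. A minimal sketch of that pattern, assuming `mtf` is the same mesh_tensorflow ops module the snippets above rely on and using illustrative device strings:

devices = ["/cpu:0", "/cpu:0"]
slices = [tf.constant(1.0), tf.constant(2.0)]  # one slice per device
bias = [tf.constant(10.0)] * len(devices)      # non-laid-out input, replicated per device
shifted = mtf.parallel(devices, tf.add, slices, bias)  # [11.0, 12.0], one result per device
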
Example #5
  def to_tf_tensor(self, *inputs):
    """Ad-hoc helper: concatenate and reduce the slices of the first input.

    Falls back to returning the second input unchanged when the last
    dimensions already match or when anything goes wrong.
    """
    inputs_save = inputs[1]
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
              else [x] * len(self.devices) for x in inputs]

    # Skip the one graph node identified by name and return the converted inputs.
    if inputs[1][0].name == "logits/add_1/parallel_0_1/Add:0":
      return inputs

    try:
      if inputs[0][0].shape[-1] != inputs[1][0].shape[-1]:
        # Concatenate all slices of the first input along its last axis on GPU:0.
        tensor_axis = len(inputs[0][0].shape) - 1
        slices = mtf.parallel(
            ["GPU:0"], tf.concat, [inputs[0]],
            axis=[tensor_axis] * len(["GPU:0"]))
      else:
        return inputs_save

      # While a slice has a leading dimension of 3 and rank below 5, average it away.
      for i in range(len(slices)):
        while slices[i].shape[0] == 3 and len(slices[i].shape) < 5:
          slices[i] = tf.reduce_mean(slices[i], axis=[0])
      return slices
    except Exception:  # keep the original broad fallback behaviour
      return inputs_save
Example #6
def alltoall_ring(xs, devices, split_axis, concat_axis):
  """MPI alltoall operation.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
  n = len(xs)
  if n == 1:
    return xs
  # set up
  # [target, source]
  parts = [[None] * n for i in xrange(n)]
  def my_split(x, size_splits):
    total_size = tf.shape(x)[split_axis]
    part_size = total_size // sum(size_splits)
    return tf.split(x, [s * part_size for s in size_splits], axis=split_axis)
  forward_message_size = (n - 1) // 2
  backward_message_size = (n - 1) - forward_message_size
  forward_messages = [None] * n
  backward_messages = [None] * n
  for i in xrange(n):
    with tf.device(devices[i]):
      if i >= backward_message_size:
        a, b, c, d = my_split(
            xs[i], [i - backward_message_size,
                    backward_message_size, 1, n - i - 1])
        backward_messages[i] = b
        parts[i][i] = c
        forward_messages[i] = tf.concat([d, a], axis=split_axis)
      else:
        a, b, c, d = my_split(
            xs[i], [i, 1, forward_message_size, backward_message_size - i])
        backward_messages[i] = tf.concat([d, a], axis=split_axis)
        parts[i][i] = b
        forward_messages[i] = c
  for step in xrange(1, max(forward_message_size, backward_message_size) + 1):
    new_forward_messages = [None] * n
    new_backward_messages = [None] * n
    for i in xrange(n):
      with tf.device(devices[i]):
        if forward_message_size > 0:
          parts[i][(i - step) % n], new_forward_messages[i] = my_split(
              forward_messages[(i - 1) % n], [1, forward_message_size - 1])
        if backward_message_size > 0:
          new_backward_messages[i], parts[i][(i + step) % n] = my_split(
              backward_messages[(i + 1) % n], [backward_message_size - 1, 1])
    forward_message_size -= 1
    backward_message_size -= 1
    forward_messages = new_forward_messages
    backward_messages = new_backward_messages
  return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
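
A minimal usage sketch for alltoall_ring (same assumptions as above); each device contributes one equal chunk of its tensor to every other device:

devices = ["/cpu:0"] * 4
xs = [tf.fill([4, 2], float(i)) for i in range(4)]  # axis 0 is split 4 ways
ys = alltoall_ring(xs, devices, split_axis=0, concat_axis=0)
# ys[t] is again [4, 2]: chunk t of every source, concatenated in source order
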
Example #7
  def to_tf_tensor2(self, *inputs):
    """Ad-hoc helper: concatenate all slices of the first input on GPU:0."""
    inputs = mtf.convert_args_to_laid_out_tensors(inputs)
    inputs = [x.tensor_list if isinstance(x, self.LaidOutTensor)
              else [x] * len(self.devices) for x in inputs]
    # Concatenate along the last axis of the first input's slices.
    tensor_axis = len(inputs[0][0].shape) - 1
    ret = mtf.parallel(
        ["GPU:0"], tf.concat, [inputs[0]],
        axis=[tensor_axis] * len(["GPU:0"]))
    return ret
def alltoall_pointtwise(xs, devices, split_axis, concat_axis):
    """MPI alltoall operation.

  Implementation of alltoall using pointwise communication.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
    n = len(xs)
    if n == 1:
        return xs
    # [target, source]
    parts = mtf.transpose_list_of_lists(
        mtf.parallel(devices, tf.split, xs, [n] * n, axis=[split_axis] * n))
    return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
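
A minimal usage sketch for alltoall_pointtwise (same assumptions as above); splitting along axis 0 and regrouping along axis 1 turns per-device row blocks into per-device column blocks:

devices = ["/cpu:0", "/cpu:0"]
xs = [tf.reshape(tf.range(4, dtype=tf.float32), [2, 2]) for _ in devices]
ys = alltoall_pointtwise(xs, devices, split_axis=0, concat_axis=1)
# ys[0] == [[0., 1., 0., 1.]] (row 0 of both inputs side by side); ys[1] holds row 1
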
Example #9
    def assign_to_slices(self, assign_fn, values):
      """Assign to the slice variables.

      Args:
        assign_fn: a function from
          (mtf.Variable, tf.Variable, tf.Tensor) -> tf.Operation
        values: a list of tf.Tensor

      Returns:
        a tf.Operation
      """
      return tf.group(mtf.parallel(
          self._mesh_impl.devices, assign_fn, [self._variable] * len(values),
          self.laid_out_tensor.all_slices, values))
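
A hypothetical usage sketch for assign_to_slices; `laid_out_var` and `new_slices` are assumed stand-ins for an existing LaidOutVariable and a list of per-slice tf.Tensors, and tf.assign is just one possible assign_fn body:

def assign_fn(mtf_var, tf_var, value):
  # hypothetical assign_fn: plain TF1 variable assignment
  return tf.assign(tf_var, value)

assign_op = laid_out_var.assign_to_slices(assign_fn, new_slices)  # laid_out_var, new_slices assumed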