Code Example #1
def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
    """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors
  Raises:
    ValueError: if devices is not a list of n strings
  """
    n = len(xs)
    if len(devices) != n:
        raise ValueError("devices must be a list of length len(xs)")
    if n == 1:
        return xs
    shape = xs[0].shape.as_list()
    # tf.logging.info("allreduce_ring shape = %s" % shape)
    size = None if None in shape else mtf.list_product(shape)
    if size is None or size < 1024 or size % n != 0:
        return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

    def _circular_shift(l, n):
        n %= len(l)
        return l[-n:] + l[:-n]

    def _flatten_and_split(x):
        # tf.reshape treats [-1] as a special value denoting 1D flattening.
        return tf.split(tf.reshape(x, [-1]), n)

    def _concat_and_reshape(xs):
        return tf.reshape(tf.concat(xs, 0), shape)

    # x_split[device][shard]: each device's flattened tensor, cut into n equal shards.
    x_split = mtf.parallel(devices, _flatten_and_split, xs)
    x_split_t = mtf.transpose_list_of_lists(x_split)

    y_split_t = []
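    # Shard i is reduced around a ring rotated by i positions, so the n
    # single-shard allreduces start on different devices and keep every
    # link of the ring busy at once.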
    for shard in range(n):  # range, not Python 2's xrange
        shard_xs = _circular_shift(x_split_t[shard], shard)
        shard_devices = _circular_shift(devices, shard)
        shard_ys = allreduce_ring_single_shard(shard_xs, shard_devices,
                                               reduction_fn_string)
        y_split_t.append(_circular_shift(shard_ys, -shard))
    y_split = mtf.transpose_list_of_lists(y_split_t)
    ys = mtf.parallel(devices, _concat_and_reshape, y_split)
    return ys
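
A minimal usage sketch for allreduce_ring, assuming the Mesh TensorFlow helpers it relies on (mtf.parallel, mtf.list_product, mtf.transpose_list_of_lists) and allreduce_ring_single_shard are importable from the same module, and that TF1-style graph construction is used. The device strings and tensor sizes below are illustrative only.

import tensorflow.compat.v1 as tf

# Four "devices" for illustration; in practice these would be distinct GPUs.
devices = ["/cpu:0", "/cpu:0", "/cpu:0", "/cpu:0"]

# One slice per device, each large enough (>= 1024 elements and divisible
# by n) to take the multi-shard ring path rather than the single-shard fallback.
xs = []
for i, d in enumerate(devices):
    with tf.device(d):
        xs.append(tf.fill([4096], float(i)))

summed = allreduce_ring(xs, devices, "SUM")
# Every element of summed[k] now equals 0.0 + 1.0 + 2.0 + 3.0 = 6.0,
# and summed[k] is placed on devices[k].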
Code Example #2
    def slicewise(self, fn, *inputs):
        """Execute a function in parallel on all slices.

    Args:
      fn: a function from tf.Tensors to tf.Tensor or a tuple of tf.Tensors.
      *inputs: a list of inputs.  Each input is either a LaidOutTensor or
        is convertible to a tf.Tensor.
    Returns:
      a LaidOutTensor, or a tuple of LaidOutTensors if fn returns a tuple.
    """
        if fn == tf.add:
            assert len(inputs) == 2
            if isinstance(inputs[0], mtf.LazyAllreduceSum):
                # sum of LazyAllreduceSum (keep delaying the allreduce)
                return inputs[0] + inputs[1]
        # convert all inputs to LaidOutTensor where possible
        inputs = mtf.convert_args_to_laid_out_tensors(inputs)
        inputs = [
            x.tensor_list if isinstance(x, self.LaidOutTensor)
            else [x] * len(self.devices)
            for x in inputs
        ]
        ret = mtf.parallel(self.devices, fn, *inputs)
        if isinstance(ret[0], tuple):
            ret = mtf.transpose_list_of_lists(ret)
            return tuple([self.LaidOutTensor(t) for t in ret])
        else:
            return self.LaidOutTensor(ret)
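
The dispatch pattern at the heart of slicewise, applying fn once per device and broadcasting plain (non-LaidOutTensor) inputs to every slice, can be sketched as a standalone function. The helper below, simple_slicewise, is illustrative only and not part of Mesh TensorFlow; it skips the LazyAllreduceSum fast path and the LaidOutTensor wrapping.

import tensorflow.compat.v1 as tf

def simple_slicewise(fn, devices, *inputs):
    """Apply fn on each device; replicate non-list inputs onto every device."""
    n = len(devices)
    expanded = [x if isinstance(x, list) else [x] * n for x in inputs]
    outputs = []
    for i, d in enumerate(devices):
        with tf.device(d):
            outputs.append(fn(*[inp[i] for inp in expanded]))
    return outputs

devices = ["/cpu:0", "/cpu:0"]
slices = [tf.constant([1.0, 2.0]), tf.constant([3.0, 4.0])]
# Multiply each slice by the broadcast scalar 2.0.
doubled = simple_slicewise(tf.multiply, devices, slices, 2.0)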
Code Example #3
def alltoall_pointtwise(xs, devices, split_axis, concat_axis):
    """MPI alltoall operation.

  Implementation of alltoall using pointwise communication.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of n strings
    split_axis: an integer
    concat_axis: an integer

  Returns:
    a list of n Tensors
  """
    n = len(xs)
    if n == 1:
        return xs
    # parts[target][source]: the chunk of xs[source] destined for device `target`.
    parts = mtf.transpose_list_of_lists(
        mtf.parallel(devices, tf.split, xs, [n] * n, axis=[split_axis] * n))
    return mtf.parallel(devices, tf.concat, parts, axis=[concat_axis] * n)
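
A minimal usage sketch for alltoall_pointtwise, again assuming mtf.parallel and mtf.transpose_list_of_lists are available. With two devices each holding a [2, 4] slice, splitting over axis 1 and concatenating over axis 0 re-partitions the data from a column split to a row split.

import tensorflow.compat.v1 as tf

devices = ["/cpu:0", "/cpu:0"]
with tf.device(devices[0]):
    x0 = tf.reshape(tf.range(8.0), [2, 4])
with tf.device(devices[1]):
    x1 = tf.reshape(tf.range(8.0, 16.0), [2, 4])

ys = alltoall_pointtwise([x0, x1], devices, split_axis=1, concat_axis=0)
# ys[0] has shape [4, 2]: the first two columns of x0 stacked above the first
# two columns of x1; ys[1] holds the last two columns of each, on devices[1].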