Esempio n. 1
0
def dense(x,
          output_dim,
          reduced_dims=None,
          expert_dims=None,
          use_bias=True,
          activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          name=None):
    """Dense layer doing (kernel*x + bias) computation.

    The kernel is drawn from a normal distribution whose standard deviation
    is the inverse square root of the product of the reduced dimension
    sizes. The bias, when used, starts at zero.

    Args:
      x: a mtf.Tensor of shape [..., reduced_dims].
      output_dim: a mtf.Dimension
      reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
        omitted, we reduce the last dimension.
      expert_dims: an optional list of mtf.Dimension which represent different
        experts. Different experts get different weights.
      use_bias: a boolean, whether to add bias.
      activation: an optional function from mtf.Tensor to mtf.Tensor
      master_dtype: a tf.dtype, forwarded to both the kernel and the bias
        variable.
      slice_dtype: a tf.dtype, forwarded to both the kernel and the bias
        variable.
      name: a string. variable scope.

    Returns:
      a mtf.Tensor of shape [..., output_dim].
    """
    if expert_dims is None:
        expert_dims = []
    if reduced_dims is None:
        reduced_dims = x.shape.dims[-1:]

    # The kernel carries one slice per expert, then the contracted
    # dimensions, then the output dimension.
    w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
    # The output keeps every dimension of x that is not contracted away.
    output_shape = mtf.Shape(
        [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])

    with tf.variable_scope(name, default_name="dense"):
        stddev = mtf.list_product(d.size for d in reduced_dims)**-0.5
        w = mtf.get_variable(
            x.mesh,
            "kernel",
            w_shape,
            initializer=tf.random_normal_initializer(stddev=stddev),
            master_dtype=master_dtype,
            slice_dtype=slice_dtype,
            activation_dtype=x.dtype)
        y = mtf.einsum([x, w], output_shape)
        if use_bias:
            # Use the same storage dtypes as the kernel so that a
            # non-default master/slice dtype applies to every variable this
            # layer creates, not just the kernel.
            b = mtf.get_variable(
                x.mesh,
                "bias",
                mtf.Shape(expert_dims + [output_dim]),
                initializer=tf.zeros_initializer(),
                master_dtype=master_dtype,
                slice_dtype=slice_dtype,
                activation_dtype=x.dtype)
            y += b
        if activation is not None:
            y = activation(y)
        return y
Esempio n. 2
0
def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
    """Compute the reduction of all Tensors and put the result everywhere.

    Performance-optimized for a ring of devices.

    Each input is flattened and split into n equal shards. Shard k is reduced
    by allreduce_ring_single_shard with the tensor and device lists rotated
    by k positions, and the per-shard results are rotated back and
    reassembled on every device.

    The whole tensors are passed to allreduce_ring_single_shard unsplit when
    the static shape is not fully known, when the tensor has fewer than 1024
    elements, or when the element count is not divisible by n.

    Args:
      xs: a list of n tf.Tensors
      devices: a list of strings
      reduction_fn_string: "SUM" or "MAX"

    Returns:
      a list of n Tensors. For n <= 1 the input list itself is returned.
    Raises:
      ValueError: if devices is not a list of n strings
    """
    n = len(xs)
    if len(devices) != n:
        raise ValueError("devices must be a list of length len(xs)")
    if n <= 1:
        # Zero or one tensor: there is nothing to combine. This also guards
        # the xs[0] access below against an empty input.
        return xs
    shape = xs[0].shape.as_list()
    size = None if None in shape else mtf.list_product(shape)
    if size is None or size < 1024 or size % n != 0:
        return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

    def _circular_shift(items, offset):
        # Rotate right by `offset`; the modulo maps negative offsets (a left
        # rotation) into the range [0, len(items)).
        offset %= len(items)
        return items[-offset:] + items[:-offset]

    def _flatten_and_split(x):
        return tf.split(tf.reshape(x, [size]), n)

    def _concat_and_reshape(pieces):
        return tf.reshape(tf.concat(pieces, 0), shape)

    # [device, shard]
    x_split = mtf.parallel(devices, _flatten_and_split, xs)
    # [shard, device]
    x_split_t = mtf.transpose_list_of_lists(x_split)

    y_split_t = []
    for shard, shard_inputs in enumerate(x_split_t):
        # Rotate tensors and devices together so each shard's reduction
        # starts at a different position in the ring.
        shard_xs = _circular_shift(shard_inputs, shard)
        shard_devices = _circular_shift(devices, shard)
        shard_ys = allreduce_ring_single_shard(shard_xs, shard_devices,
                                               reduction_fn_string)
        # Undo the rotation so results line up with the original device order.
        y_split_t.append(_circular_shift(shard_ys, -shard))
    y_split = mtf.transpose_list_of_lists(y_split_t)
    return mtf.parallel(devices, _concat_and_reshape, y_split)