Example #1
File: layers.py Project: qixiuai/mesh
def dense(x, output_dim, reduced_dims=None, expert_dims=None,
          use_bias=True, activation=None,
          master_dtype=tf.float32,
          slice_dtype=tf.float32,
          variable_dtype=None,
          name=None):
  """Dense layer doing (kernel*x + bias) computation.

  Args:
    x: a mtf.Tensor of shape [..., reduced_dims].
    output_dim: a mtf.Dimension
    reduced_dims: an optional list of mtf.Dimensions of x to be reduced. If
      omitted, we reduce the last dimension.
    expert_dims: an optional list of mtf.Dimension which represent different
      experts. Different experts get different weights.
    use_bias: a boolean, whether to add bias.
    activation: an optional function from mtf.Tensor to mtf.Tensor
    master_dtype: a tf.dtype (deprecated - use variable_dtype)
    slice_dtype: a tf.dtype (deprecated - use variable_dtype)
    variable_dtype: a mtf.VariableDType
    name: a string. variable scope.

  Returns:
    a mtf.Tensor of shape [..., output_dim].
  """
  if variable_dtype is None:
    variable_dtype = mtf.VariableDType(master_dtype, slice_dtype, x.dtype)
  if expert_dims is None:
    expert_dims = []
  if reduced_dims is None:
    reduced_dims = x.shape.dims[-1:]
  w_shape = mtf.Shape(expert_dims + reduced_dims + [output_dim])
  output_shape = mtf.Shape(
      [d for d in x.shape.dims if d not in reduced_dims] + [output_dim])

  with tf.variable_scope(name, default_name="dense"):
    stddev = mtf.list_product(d.size for d in reduced_dims) ** -0.5
    w = mtf.get_variable(
        x.mesh,
        "kernel",
        w_shape,
        initializer=tf.random_normal_initializer(stddev=stddev),
        dtype=variable_dtype)
    w = mtf.cast(w, x.dtype)
    y = mtf.einsum([x, w], output_shape)
    if use_bias:
      b = mtf.get_variable(
          x.mesh,
          "bias",
          mtf.Shape(expert_dims + [output_dim]),
          initializer=tf.zeros_initializer(),
          dtype=variable_dtype)
      y += b
    if activation is not None:
      y = activation(y)
    return y
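
Below is a minimal usage sketch of the layer above (not from the source): it assumes a working Mesh TensorFlow install, TF1-style graph mode as in the snippet, and made-up dimension names.

# Hypothetical usage sketch: project a [batch, io] tensor to [batch, hidden].
import mesh_tensorflow as mtf
import tensorflow.compat.v1 as tf

graph = mtf.Graph()
mesh = mtf.Mesh(graph, "example_mesh")
batch = mtf.Dimension("batch", 8)
io = mtf.Dimension("io", 16)
hidden = mtf.Dimension("hidden", 32)
x = mtf.zeros(mesh, mtf.Shape([batch, io]))
# Reduce the "io" dimension, add a "hidden" dimension, add bias, apply ReLU.
y = dense(x, hidden, reduced_dims=[io], activation=mtf.relu, name="proj")
print(y.shape)  # expected: a Shape with dimensions batch=8, hidden=32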
Example #2
def allreduce_ring(xs, devices, reduction_fn_string="SUM"):
    """Compute the reduction of all Tensors and put the result everywhere.

  Performance-optimized for a ring of devices.

  Args:
    xs: a list of n tf.Tensors
    devices: a list of strings
    reduction_fn_string: "SUM" or "MAX"

  Returns:
    a list of n Tensors
  Raises:
    ValueError: if devices is not a list of n strings
  """
    n = len(xs)
    if len(devices) != n:
        raise ValueError("devices must be a list of length len(xs)")
    if n == 1:
        return xs
    shape = xs[0].shape.as_list()
    # tf.logging.info("allreduce_ring shape = %s" % shape)
    size = None if None in shape else mtf.list_product(shape)
    if size is None or size < 1024 or size % n != 0:
        return allreduce_ring_single_shard(xs, devices, reduction_fn_string)

    def _circular_shift(l, n):
        n %= len(l)
        return l[-n:] + l[:-n]

    def _flatten_and_split(x):
        # tf.reshape treats [-1] as a special value denoting 1D flattening.
        return tf.split(tf.reshape(x, [-1]), n)

    def _concat_and_reshape(xs):
        return tf.reshape(tf.concat(xs, 0), shape)

    # [device, shard]
    x_split = mtf.parallel(devices, _flatten_and_split, xs)
    x_split_t = mtf.transpose_list_of_lists(x_split)

    y_split_t = []
    for shard in range(n):
        shard_xs = _circular_shift(x_split_t[shard], shard)
        shard_devices = _circular_shift(devices, shard)
        shard_ys = allreduce_ring_single_shard(shard_xs, shard_devices,
                                               reduction_fn_string)
        y_split_t.append(_circular_shift(shard_ys, -shard))
    y_split = mtf.transpose_list_of_lists(y_split_t)
    ys = mtf.parallel(devices, _concat_and_reshape, y_split)
    return ys
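
The function above splits every tensor into n shards, reduces each shard around the ring starting from a different device (the circular shifts), and reassembles the result on every device. A pure-Python sketch of that pattern, not the library code, using plain lists in place of per-device tensors:

# Toy model of the shard-and-rotate pattern (hypothetical, for illustration).
def toy_allreduce_ring(xs):
    n = len(xs)
    shard_len = len(xs[0]) // n
    # shards[device][shard]: each device's vector cut into n equal shards
    shards = [[x[i * shard_len:(i + 1) * shard_len] for i in range(n)]
              for x in xs]

    def shift(l, k):  # same convention as _circular_shift above
        k %= len(l)
        return l[-k:] + l[:-k]

    reduced = [None] * n
    for s in range(n):
        # Rotating the column only changes where each shard's ring reduction
        # starts (which is what balances link usage in the real code); the
        # sum is the same either way.
        column = shift([shards[d][s] for d in range(n)], s)
        reduced[s] = [sum(vals) for vals in zip(*column)]
    # every device receives the concatenation of all reduced shards
    return [sum(reduced, []) for _ in range(n)]

print(toy_allreduce_ring([[1, 2, 3, 4], [10, 20, 30, 40]]))
# -> [[11, 22, 33, 44], [11, 22, 33, 44]]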
Example #3
    def __init__(self, spec, physical_shape):
        """Constructs a HierarchicalTiling.

    spec is a list corresponding to the logical dimensions.

    spec[i] corresponds to the i-th logical dimension and consists of a name
      and a list of integers, the list being the shape of logical axis i when
      it is physically projected to the physical mesh and then compacted.

    Striding information is omitted.  By convention, the earlier dimensions
      get more strided, so the axis corresponding to the last dimension always
      gets projected to the tile specified by its shape.

    Args:
      spec: a list of (string, list-of-integers) pairs
      physical_shape: a list of integers
    """
        self._names = [p[0] for p in spec]
        logical_ndims = len(spec)
        physical_ndims = len(physical_shape)
        projected_shapes = [p[1] for p in spec]
        if logical_ndims > 0 and projected_shapes[0] is None:
            # fill in missing value
            projected_shapes[0] = list(physical_shape)
            for s in projected_shapes[1:]:
                for i, x in enumerate(s):
                    projected_shapes[0][i] //= x
        # compute strides, and verify that the spec is valid.
        products = [1] * physical_ndims
        sizes_and_strides = []
        for s in reversed(projected_shapes):
            sizes_and_strides.append([(size, stride)
                                      for size, stride in zip(s, products)])
            for i, x in enumerate(s):
                products[i] *= x
        if products != physical_shape:
            raise ValueError("mesh spec multiplies to the wrong size"
                             "spec=%s physical_shape=%s products=%s" %
                             (spec, physical_shape, products))
        sizes_and_strides.reverse()
        self._physical_coordinates = _logical_to_physical_v1(
            sizes_and_strides, physical_shape)
        self._logical_to_physical = [
            mtf.processor_coordinates_to_pnum(physical_shape, c)
            for c in self._physical_coordinates
        ]
        self._mesh_shape = mtf.Shape([
            mtf.Dimension(name, mtf.list_product(s))
            for name, s in zip(self._names, projected_shapes)
        ])
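
A hypothetical spec for a [4, 4, 2] physical mesh, to illustrate the format described above: the second logical axis is compacted to a [2, 2, 2] tile, and the leading axis gets the remaining [2, 2, 1] factors (it could also be given as None and filled in by the constructor, as shown above).

# Hypothetical example, not from the source.
spec = [("data", [2, 2, 1]), ("model", [2, 2, 2])]
tiling = HierarchicalTiling(spec, physical_shape=[4, 4, 2])
print(tiling._mesh_shape)  # set above; expected dimensions data=4, model=8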
Example #4
    def spec_to_mesh_shape(cls, spec, num_processors):
        """Compute mesh shape even without knowing the physical shape.

    This is useful in cases where the mesh shape must be computed before
    you know the physical_shape.

    Args:
      spec: a list of (string, list-of-integers) pairs
      num_processors: an integer
    Returns:
      a mtf.Shape
    """
        logical_ndims = len(spec)
        names = [p[0] for p in spec]
        sizes = [p[1] for p in spec]
        sizes = [None if s is None else mtf.list_product(s) for s in sizes]
        if logical_ndims > 0 and sizes[0] is None:
            sizes[0] = num_processors // mtf.list_product(sizes[1:])
        if mtf.list_product(sizes) != num_processors:
            raise ValueError("product of spec must be num_processors"
                             " spec=%s num_processors=%s" %
                             (spec, num_processors))
        return mtf.Shape(
            [mtf.Dimension(name, s) for name, s in zip(names, sizes)])
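
For instance (hypothetical numbers, and assuming the method is exposed as a classmethod of HierarchicalTiling, as the cls argument suggests), a None leading size is inferred from num_processors, mirroring the constructor's convention:

# 32 processors total; the "data" size is inferred as 32 // (2 * 2 * 2) = 4.
shape = HierarchicalTiling.spec_to_mesh_shape(
    [("data", None), ("model", [2, 2, 2])], num_processors=32)
print(shape)  # expected dimensions: data=4, model=8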
Example #5
def _logical_1d_to_physical_subspace_auto(sizes_and_strides, physical_shape):
    """Maps logical 1d mesh to subspace of physical nd mesh.

  We are mapping a 1d logical mesh to a subspace (a strided slice containing the
  origin) of an n-dimensional physical mesh.

  output[i] contains the coordinate-tuple in the physical mesh for the i-th
  logical processor.

  sizes_and_strides is a list of (size, stride) pairs specifying the dimensions
  of the strided slice. For example,
    sizes_and_strides=[(2, 16), (4, 1)] would represent the slice containing
    [(0, 0), (0, 1), (0, 2), (0, 3),
     (16, 0), (16, 1), (16, 2), (16, 3)]

  This function heuristically picks an order, with the goal of optimizing
  allreduce performance.

  Args:
    sizes_and_strides: a list of n (size, stride) pairs
    physical_shape: ignored
  Returns:
    a list of coordinate-lists
  """
    del physical_shape
    ndims = len(sizes_and_strides)
    sizes = [p[0] for p in sizes_and_strides]
    strides = [p[1] for p in sizes_and_strides]
    n = mtf.list_product(sizes)
    if ndims >= 2 and sizes[0] > 1 and sizes[1] > 1:
        ring = _ring_2d(sizes[0], sizes[1])
        ret = []
        sizes_combined = [sizes[0] * sizes[1]] + sizes[2:]
        for logical_pnum in range(n):
            logical_coord = mtf.pnum_to_processor_coordinates(
                sizes_combined, logical_pnum)
            ret.append(list(ring[logical_coord[0]]) + logical_coord[1:])
    else:
        ret = [
            mtf.pnum_to_processor_coordinates(sizes, logical_pnum)
            for logical_pnum in range(n)
        ]
    # multiply by strides
    ret = [[x * stride for x, stride in zip(pcoord, strides)]
           for pcoord in ret]
    return ret
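
The docstring example above, written as a call (hypothetical; the visiting order depends on _ring_2d, so only the set of coordinates is predictable):

coords = _logical_1d_to_physical_subspace_auto([(2, 16), (4, 1)], None)
# coords is some ring-ordered arrangement of the 8 coordinate pairs
# {0, 16} x {0, 1, 2, 3} listed in the docstring.
print(coords)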
Example #6
def auto_logical_to_physical_tpu(logical_shape,
                                 physical_shape,
                                 return_coordinates=False):
    """Set up a mapping from logical to physical cores for TPU.

  We will try to set up a mapping so that allreduce operations are relatively
  fast, prioritizing the later dimensions in the mesh_shape.

  Example:

  auto_logical_to_physical_tpu(
    logical_shape=[16, 8], physical_shape=[8, 8, 1, 2])

  Heuristics in this function are subject to change.

  Args:
    logical_shape: a list of integers
    physical_shape: a list of integers - typically [X, Y, 1, cores]
    return_coordinates: a boolean - return a list of integer lists (coordinates)
       instead of a list of processor indices

  Returns:
    logical_to_physical: a permutation of range(product(physical_shape))
  """
    tf.logging.info("auto_logical_to_physical_tpu "
                    "logical_shape=%s physical_shape=%s" %
                    (logical_shape, physical_shape))
    if mtf.list_product(logical_shape) != mtf.list_product(physical_shape):
        raise ValueError(
            "physical and logical shapes must have the same product "
            "physical_shape=%s logical_shape=%s" %
            (physical_shape, logical_shape))
    # drop logical dimensions of size 1
    logical_shape = [i for i in logical_shape if i != 1]
    num_cores = mtf.list_product(logical_shape)

    # For physical shapes different from what we are used to [2^a, 2^b, 2],
    #   return a simple default value (a lexicographic ordering)
    def _default_value():
        default = list(range(num_cores))
        if return_coordinates:
            default = [mtf.pnum_to_processor_coordinates(physical_shape, i)
                       for i in default]
        return default

    if len(physical_shape) == 4 and physical_shape[2] == 1:
        # This is what we currently expect as the TPU physical shape.
        # The first two dimensions (in backwards order) correspond to the chip
        #   number and the last dimension corresponds to two cores on a chip.
        physical_shape = [
            physical_shape[1], physical_shape[0], physical_shape[3]
        ]
    else:
        tf.logging.warning("Unrecognized format for tpu physical shape")
    if len(physical_shape) != 3:
        return _default_value()
    p0, p1, p2 = physical_shape
    if p2 != 2:
        return _default_value()
    for dimsize in [p0, p1]:
        # if dimsize not a power of 2, give up
        if dimsize & (dimsize - 1):
            return _default_value()
    # At this point, the physical shape has at least 1x1x2=2 cores, so there
    #   must be at least one logical dimension.
    assert logical_shape
    if len(logical_shape) == 1:
        # ring of p0 x p1 chips
        ring = _ring_2d(p0, p1)
        logical_to_physical = []
        for logical_pnum in range(num_cores):
            # Go through all chips using core 0, then go through all chips
            #   backwards using core 1.  This is better in the case where
            #   one of the tile dimensions is 1, so the last chip is not adjacent
            #   to the first chip.
            core_on_chip = logical_pnum // (p0 * p1)
            if core_on_chip == 0:
                chip_num = logical_pnum
            else:
                chip_num = num_cores - 1 - logical_pnum
            i, j = ring[chip_num]
            logical_to_physical.append((i, j, core_on_chip))
    else:
        # We have a p0 x p1 rectangle of chips, which we will tile with rectangular
        #   tiles.  The first logical dimension corresponds to the number of tiles,
        #   and the other logical dimensions will correspond to position within a
        #   tile.
        num_tiles = logical_shape[0]
        tile_chips = num_cores // num_tiles // p2
        # If we can, we make each tile occupy exactly one row or column of chips.
        # Otherwise, we make each tile approximately square.
        if len(logical_shape) == 2 and tile_chips == p0:
            t0, t1 = [tile_chips, 1]
        elif len(logical_shape) == 2 and tile_chips == p1:
            t0, t1 = [1, tile_chips]
        else:
            # try to make the tile approximately square
            lg_tile_chips = int(math.log(tile_chips, 2))
            t0 = 2**(lg_tile_chips // 2)
            # make sure that the tile fits in the mesh - i.e.
            #   t0 <= p0
            #   t1 == tile_chips // t0 <= p1
            t0 = min(t0, p0)
            t0 = max(t0, tile_chips // p1)
            t1 = tile_chips // t0
        # recursive call to find mapping for one tile
        tile_logical_to_physical = auto_logical_to_physical_tpu(
            logical_shape[1:], [t0, t1, p2], return_coordinates=True)
        tiles_ring = _ring_2d(p0 // t0, p1 // t1)
        logical_to_physical = []
        for logical_pnum in range(num_cores):
            logical_tile_num = logical_pnum // (t0 * t1 * p2)
            logical_pos_in_tile = logical_pnum % (t0 * t1 * p2)
            logical_to_physical.append(
                (tiles_ring[logical_tile_num][0] * t0 +
                 tile_logical_to_physical[logical_pos_in_tile][0],
                 tiles_ring[logical_tile_num][1] * t1 +
                 tile_logical_to_physical[logical_pos_in_tile][1],
                 tile_logical_to_physical[logical_pos_in_tile][2]))
    if return_coordinates:
        return logical_to_physical
    else:
        return [
            mtf.processor_coordinates_to_pnum(physical_shape, coord)
            for coord in logical_to_physical
        ]
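
The docstring example above, written as a call; per the stated return contract it yields a permutation of range(128), or (x, y, core) coordinate triples when return_coordinates=True:

perm = auto_logical_to_physical_tpu(
    logical_shape=[16, 8], physical_shape=[8, 8, 1, 2])
assert sorted(perm) == list(range(128))  # a permutation, per the docstring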
Example #7
def _logical_to_physical_v1(sizes_and_strides,
                            physical_shape,
                            fn_1d=_logical_1d_to_physical_subspace_auto):
    """Maps logical m-dimensional mesh to physical n-dimensional mesh.

  Also see comments to _logical_1d_to_physical_subspace_auto.

  We are mapping an m-dimensional logical mesh to an n-dimensional physical mesh.

  output[i] contains the coordinate-tuple in the physical mesh for the i-th
  logical processor (if the logical processors are ordered lexicographically).

  sizes_and_strides is a list of m lists of n (size, stride) pairs.

  sizes_and_strides[i] specifies the subspace (strided slice containing the
  origin) of the physical mesh covered by axis i of the logical mesh.  See
  comments to _logical_1d_to_physical_subspace_auto for more detail.

  For example, say we have a physical mesh with shape [4, 4, 2] and a logical
  mesh with shape [4, 8].  We want to divide the physical mesh into 4 tiles,
  each with shape [2, 2, 2].  The first logical dimension corresponds to which
  tile, and the second logical dimension corresponds to position within a tile.
  This would correspond to:
     physical_shape=[4, 4, 2]
     sizes_and_strides=[[(2, 2), (2, 2), (1, 2)], [(2, 1), (2, 1), (2, 1)]]

  physical_shape can be inferred from sizes_and_strides, but is passed in for
  error checking.

  Args:
    sizes_and_strides: a list of m lists of n (size, stride) pairs
    physical_shape: a list of integers
    fn_1d: a function like _logical_1d_to_physical_subspace_auto
  Returns:
    a list of coordinate-lists
  """
    pndims = len(physical_shape)
    logical_shape = [
        mtf.list_product([p[0] for p in l]) for l in sizes_and_strides
    ]
    n = mtf.list_product(physical_shape)
    if n != mtf.list_product(logical_shape):
        raise ValueError("logical size and physical size must match "
                         "- got sizes_and_strides=%s physical_shape=%s" %
                         (sizes_and_strides, physical_shape))
    dimension_layouts = [fn_1d(l, physical_shape) for l in sizes_and_strides]
    tf.logging.info("physical_shape: %s" % physical_shape)
    tf.logging.info("sizes_and_strides: %s" % sizes_and_strides)
    for i, l in enumerate(dimension_layouts):
        tf.logging.info("dimension_layout %s: %s" % (i, l))
    ret = []
    for logical_pnum in range(n):
        logical_coordinates = mtf.pnum_to_processor_coordinates(
            logical_shape, logical_pnum)
        physical_coordinates = [0] * pndims
        for logical_axis, logical_coord in enumerate(logical_coordinates):
            for physical_axis in range(pndims):
                physical_coordinates[physical_axis] += (
                    dimension_layouts[logical_axis][logical_coord]
                    [physical_axis])
        ret.append(physical_coordinates)
    # verify that we have indeed covered all the processors
    l2p = [mtf.processor_coordinates_to_pnum(physical_shape, c) for c in ret]
    if sorted(l2p) != list(range(n)):
        raise ValueError(
            "logical_to_physical produced something that was not a permutation."
            " sizes_and_strides=%s physical_shape=%s ret=%s" %
            (sizes_and_strides, physical_shape, ret))
    return ret
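
The [4, 4, 2] / [4, 8] tiling example from the docstring, written as a call (hypothetical):

coords = _logical_to_physical_v1(
    [[(2, 2), (2, 2), (1, 2)], [(2, 1), (2, 1), (2, 1)]],
    [4, 4, 2])
# One physical coordinate-triple per logical processor; the check at the end
# of the function guarantees the mapping covers all 32 physical positions.
print(len(coords))  # 32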