Example #1
def _logical_to_physical_v1(sizes_and_strides,
                            physical_shape,
                            fn_1d=_logical_1d_to_physical_subspace_auto):
    """Maps logical m-dimensional mesh to physical n-dimensional mesh.

  Also see comments to _logical_1d_to_physical_subspace_auto.

  We are mapping an m-dimensional logical mesh to an n-dimensional physical mesh.

  output[i] contains the coordinate-tuple in the physical mesh for the i-th
  logical processor (if the logical processors are ordered lexicographically).

  sizes_and_strides is a list of m lists of n (size, stride) pairs.

  sizes_and_strides[i] specifies the subspace (strided slice containing the
  origin) of the physical mesh covered by axis i of the logical mesh.  See
  comments to _logical_1d_to_physical_subspace_auto for more detail.

  For example, say we have a physical mesh with shape [4, 4, 2] and a logical
  mesh with shape [4, 8].  We want to divide the physical mesh into 4 tiles,
  each with shape [2, 2, 2].  The first logical dimension corresponds to which
  tile, and the second logical dimension corresponds to position within a tile.
  This would correspond to:
     physical_shape=[4, 4, 2]
     sizes_and_strides=[[(2, 2), (2, 2), (1, 2)], [(2, 1), (2, 1), (2, 1)]]

  physical_shape can be inferred from sizes_and_strides, but is passed in for
  error checking.

  Args:
    sizes_and_strides: a list of m lists of n (size, stride) pairs
    physical_shape: a list of integers
    fn_1d: a function like _logical_1d_to_physical_subspace_auto
  Returns:
    a list of coordinate-lists
  """
    pndims = len(physical_shape)
    logical_shape = [
        mtf.list_product([p[0] for p in l]) for l in sizes_and_strides
    ]
    n = mtf.list_product(physical_shape)
    if n != mtf.list_product(logical_shape):
        raise ValueError("logical size and physical size must match "
                         "- got sizes_and_strides=%s physical_shape=%s" %
                         (sizes_and_strides, physical_shape))
    dimension_layouts = [fn_1d(l, physical_shape) for l in sizes_and_strides]
    tf.logging.info("physical_shape: %s" % physical_shape)
    tf.logging.info("sizes_and_strides: %s" % sizes_and_strides)
    for i, l in enumerate(dimension_layouts):
        tf.logging.info("dimension_layout %s: %s" % (i, l))
    ret = []
    for logical_pnum in range(n):
        logical_coordinates = mtf.pnum_to_processor_coordinates(
            logical_shape, logical_pnum)
        physical_coordinates = [0] * pndims
        for logical_axis, logical_coord in enumerate(logical_coordinates):
            for physical_axis in range(pndims):
                physical_coordinates[physical_axis] += (
                    dimension_layouts[logical_axis][logical_coord]
                    [physical_axis])
        ret.append(physical_coordinates)
    # verify that we have indeed covered all the processors
    l2p = [mtf.processor_coordinates_to_pnum(physical_shape, c) for c in ret]
    if sorted(l2p) != list(range(n)):
        raise ValueError(
            "logical_to_physical produced something that was not a permutation."
            " sizes_and_strides=%s physical_shape=%s ret=%s" %
            (sizes_and_strides, physical_shape, ret))
    return ret
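
A short, self-contained sketch of the docstring example above (logical mesh [4, 8]
on physical mesh [4, 4, 2]). strided_subspace is a hypothetical, simplified
stand-in for _logical_1d_to_physical_subspace_auto: it only enumerates the strided
slice lexicographically, whereas the real helper also picks an ordering that keeps
neighboring logical processors physically close. The combination step (summing the
per-axis offsets) mirrors the nested loops of _logical_to_physical_v1.

import itertools

def strided_subspace(sizes_and_strides_1d):
    """Enumerate, lexicographically, the strided slice containing the origin."""
    axes = [[k * stride for k in range(size)]
            for size, stride in sizes_and_strides_1d]
    return [list(coord) for coord in itertools.product(*axes)]

physical_shape = [4, 4, 2]
sizes_and_strides = [[(2, 2), (2, 2), (1, 2)],   # logical axis 0: which tile
                     [(2, 1), (2, 1), (2, 1)]]   # logical axis 1: position in tile

layouts = [strided_subspace(l) for l in sizes_and_strides]

# Each logical processor's physical coordinate is the elementwise sum of the
# offsets chosen for its coordinate along every logical axis, exactly as in
# the nested loops of _logical_to_physical_v1.
mapping = []
for i in range(4):        # logical axis 0: which tile
    for j in range(8):    # logical axis 1: position within the tile
        mapping.append([a + b for a, b in zip(layouts[0][i], layouts[1][j])])

print(mapping[:4])   # [[0, 0, 0], [0, 0, 1], [0, 1, 0], [0, 1, 1]]

# Every physical coordinate of the [4, 4, 2] mesh is covered exactly once.
assert sorted(map(tuple, mapping)) == sorted(
    itertools.product(range(4), range(4), range(2)))
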
Example #2
def auto_logical_to_physical_tpu(logical_shape,
                                 physical_shape,
                                 return_coordinates=False):
    """Set up a mapping from logical to physical cores for TPU.

  We will try to set up a mapping so that allreduce operations are relatively
  fast, prioritizing the later dimensions in the mesh_shape.

  Example:

  auto_logical_to_physical_tpu(
    logical_shape=[16, 8], physical_shape=[8, 8, 1, 2])

  Heuristics in this function are subject to change.

  Args:
    logical_shape: a list of integers
    physical_shape: a list of integers - typically [X, Y, 1, cores]
    return_coordinates: a boolean - return a list of integer lists (coordinates)
       instead of a list of processor indices

  Returns:
    logical_to_physical: a permutation of range(product(physical_shape)), or
      a list of coordinate lists if return_coordinates is True
  """
    tf.logging.info("auto_logical_to_physical_tpu "
                    "logical_shape=%s physical_shape=%s" %
                    (logical_shape, physical_shape))
    if mtf.list_product(logical_shape) != mtf.list_product(physical_shape):
        raise ValueError(
            "physical and logical shapes must have the same product "
            "physical_shape=%s logical_shape=%s" %
            (physical_shape, logical_shape))
    # drop logical dimensions of size 1
    logical_shape = [i for i in logical_shape if i != 1]
    num_cores = mtf.list_product(logical_shape)

    # For physical shapes different from what we are used to [2^a, 2^b, 2],
    #   return a simple default value (a lexicographic ordering)
    def _default_value():
        default = list(range(num_cores))
        if return_coordinates:
            default = [mtf.pnum_to_processor_coordinates(physical_shape, pnum)
                       for pnum in default]
        return default

    if len(physical_shape) == 4 and physical_shape[2] == 1:
        physical_shape = physical_shape_3d_from_topology_proto_4d(
            physical_shape)
    elif len(physical_shape) != 3:
        tf.logging.warning("Unrecognized format for tpu physical shape")
        return _default_value()
    # physical_shape is a triple of rows, cols, cores
    p0, p1, p2 = physical_shape
    if p2 != 2:
        return _default_value()
    for dimsize in [p0, p1]:
        # if dimsize is not a power of 2, give up
        if dimsize & (dimsize - 1):
            return _default_value()
    # At this point, the physical shape has at least 1x1x2=2 cores, so there
    #   must be at least one logical dimension.
    assert logical_shape
    if len(logical_shape) == 1:
        # ring of p0 x p1 chips
        ring = _ring_2d(p0, p1)
        logical_to_physical = []
        for logical_pnum in range(num_cores):
            core_on_chip = logical_pnum % 2
            chip_num = logical_pnum // 2
            i, j = ring[chip_num]
            logical_to_physical.append((i, j, core_on_chip))
    else:
        # We have a p0 x p1 rectangle of chips, which we will tile with
        #   rectangular tiles.  The first logical dimension corresponds to which
        #   tile, and the other logical dimensions correspond to the position
        #   within a tile.
        num_tiles = logical_shape[0]
        tile_chips = num_cores // num_tiles // p2
        # If we can, we make each tile occupy exactly one row or column of chips.
        # Otherwise, we make each tile approximately square.
        if len(logical_shape) == 2 and tile_chips == p0:
            t0, t1 = [tile_chips, 1]
        elif len(logical_shape) == 2 and tile_chips == p1:
            t0, t1 = [1, tile_chips]
        else:
            # try to make the tile approximately square
            lg_tile_chips = int(math.log(tile_chips, 2))
            t0 = 2**(lg_tile_chips // 2)
            # make sure that the tile fits in the mesh - i.e.
            #   t0 <= p0
            #   t1 == tile_chips // t0 <= p1
            t0 = min(t0, p0)
            t0 = max(t0, tile_chips // p1)
            t1 = tile_chips // t0
        # recursive call to find mapping for one tile
        tile_logical_to_physical = auto_logical_to_physical_tpu(
            logical_shape[1:], [t0, t1, p2], return_coordinates=True)
        tiles_ring = _ring_2d(p0 // t0, p1 // t1)
        logical_to_physical = []
        for logical_pnum in range(num_cores):
            logical_tile_num = logical_pnum // (t0 * t1 * p2)
            logical_pos_in_tile = logical_pnum % (t0 * t1 * p2)
            logical_to_physical.append(
                (tiles_ring[logical_tile_num][0] * t0 +
                 tile_logical_to_physical[logical_pos_in_tile][0],
                 tiles_ring[logical_tile_num][1] * t1 +
                 tile_logical_to_physical[logical_pos_in_tile][1],
                 tile_logical_to_physical[logical_pos_in_tile][2]))
    tf.logging.info("auto_logical_to_physical_tpu logical_to_physical = %s" %
                    logical_to_physical)
    if return_coordinates:
        return logical_to_physical
    else:
        return [
            mtf.processor_coordinates_to_pnum(physical_shape, coord)
            for coord in logical_to_physical
        ]
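
To make the tiling heuristic concrete, here is a sketch that walks the
tile-sizing branch above through the docstring example,
auto_logical_to_physical_tpu(logical_shape=[16, 8], physical_shape=[8, 8, 1, 2]),
i.e. an 8x8 grid of chips with 2 cores each. The arithmetic is copied from the
function body; only the variable setup is new.

import math

p0, p1, p2 = 8, 8, 2                        # rows, cols, cores per chip
logical_shape = [16, 8]
num_cores = p0 * p1 * p2                    # 128 == 16 * 8
num_tiles = logical_shape[0]                # 16 tiles, one per first-axis index
tile_chips = num_cores // num_tiles // p2   # 4 chips per tile

if len(logical_shape) == 2 and tile_chips == p0:
    t0, t1 = tile_chips, 1                  # a full column of chips
elif len(logical_shape) == 2 and tile_chips == p1:
    t0, t1 = 1, tile_chips                  # a full row of chips
else:
    lg_tile_chips = int(math.log(tile_chips, 2))
    t0 = 2 ** (lg_tile_chips // 2)          # start from a roughly square tile
    t0 = min(t0, p0)                        # the tile must fit vertically ...
    t0 = max(t0, tile_chips // p1)          # ... and horizontally
    t1 = tile_chips // t0

print(t0, t1)   # 2 2 -> each tile is a 2x2 block of chips (8 cores), and the
                # 8x8 grid of chips holds a 4x4 arrangement of such tiles, which
                # the recursive call and _ring_2d then stitch together.
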