Example #1
0
def _logical_1d_to_physical_subspace_auto(sizes_and_strides, physical_shape):
    """Lays out a 1d logical mesh over a strided slice of the physical mesh.

    The slice contains the origin of the n-dimensional physical mesh and is
    described by one (size, stride) pair per physical dimension.  For example,
      sizes_and_strides=[(2, 16), (4, 1)] describes the slice containing
      [(0, 0), (0, 1), (0, 2), (0, 3),
       (16, 0), (16, 1), (16, 2), (16, 3)]

    output[i] is the coordinate list in the physical mesh of the i-th logical
    processor.  The visiting order is chosen heuristically, with the goal of
    optimizing allreduce performance: when the first two dimensions both have
    size > 1 they are traversed in the order given by `_ring_2d`; otherwise the
    plain `mtf.pnum_to_processor_coordinates` order is used.

    Args:
      sizes_and_strides: a list of n (size, stride) pairs
      physical_shape: ignored
    Returns:
      a list of coordinate-lists
    """
    del physical_shape  # accepted for signature compatibility only
    num_axes = len(sizes_and_strides)
    axis_sizes = [entry[0] for entry in sizes_and_strides]
    axis_strides = [entry[1] for entry in sizes_and_strides]
    num_processors = mtf.list_product(axis_sizes)

    leading_pair_is_2d = (
        num_axes >= 2 and axis_sizes[0] > 1 and axis_sizes[1] > 1)
    if not leading_pair_is_2d:
        unit_coords = [
            mtf.pnum_to_processor_coordinates(axis_sizes, pnum)
            for pnum in range(num_processors)
        ]
    else:
        ring = _ring_2d(axis_sizes[0], axis_sizes[1])
        # Treat the first two axes as a single axis indexed along the ring.
        merged_sizes = [axis_sizes[0] * axis_sizes[1]] + axis_sizes[2:]
        unit_coords = []
        for pnum in range(num_processors):
            merged_coord = mtf.pnum_to_processor_coordinates(
                merged_sizes, pnum)
            ring_position = list(ring[merged_coord[0]])
            unit_coords.append(ring_position + merged_coord[1:])

    # Scale the unit-stride coordinates by the stride of each dimension.
    return [
        [component * stride
         for component, stride in zip(coord, axis_strides)]
        for coord in unit_coords
    ]
Example #2
0
    def receive(self, x, mesh_axis, source_pcoord):
        """Collective receive in groups.

        Each group contains the processors that differ only in mesh_axis.

        ```python
        group_size = self.shape[mesh_axis].size
        ```

        Args:
          x: a LaidOutTensor
          mesh_axis: an integer
          source_pcoord: a list of optional integers. Each element is either
            None or an integer in [0, group_size). If source_pcoord[k] is
            None, then the output for the k-th processor in each group is a
            zero tensor. If source_pcoord[k] is not None, then the output for
            the k-th processor in each group is equal to the input for the
            source_pcoord[k]-th processor in that group.

        Returns:
          a LaidOutTensor
        """
        x = x.to_laid_out_tensor()
        t = x.one_slice
        source_target_pairs = []

        # `range` rather than the Python-2-only `xrange`, matching the rest of
        # the file and avoiding a NameError on Python 3.
        for pnum in range(self.size):
            coord = mtf.pnum_to_processor_coordinates(self.shape, pnum)
            k = coord[mesh_axis]
            source_k = source_pcoord[k]
            if source_k is None:
                # This processor receives nothing; no pair is emitted for it.
                continue
            # The source differs from the target only along mesh_axis.
            coord[mesh_axis] = source_k
            source_pnum = mtf.processor_coordinates_to_pnum(self.shape, coord)
            source_target_pairs.append(
                [self.l2p(source_pnum), self.l2p(pnum)])

        if not source_target_pairs:
            # Nobody receives anything, so skip the collective entirely.
            ret = tf.zeros_like(t, t.dtype)
        elif t.dtype in [tf.float32, tf.bfloat16, tf.int32]:
            ret = tpu_ops.collective_permute(t, source_target_pairs)
        else:
            # If t is not one of the allowed types, cast and cast back.
            permuted = tpu_ops.collective_permute(
                tf.cast(t, tf.float32), source_target_pairs)
            ret = tf.cast(permuted, t.dtype)

        return self.LaidOutTensor([ret])
Example #3
0
 def _default_value():
     """Returns the fallback ordering over `num_cores` processors.

     Reads `num_cores` and `return_coordinates` from the enclosing scope.
     Without `return_coordinates` the result is the list [0, ..., num_cores-1];
     with it, each processor number is passed through
     `mtf.pnum_to_processor_coordinates`.
     """
     core_ids = list(range(num_cores))
     if not return_coordinates:
         return core_ids
     # NOTE(review): other call sites of pnum_to_processor_coordinates pass
     # (shape, pnum); this one passes only the pnum -- confirm whether a
     # shape argument is missing here.
     return [mtf.pnum_to_processor_coordinates(i) for i in core_ids]
Example #4
0
def _logical_to_physical_v1(sizes_and_strides,
                            physical_shape,
                            fn_1d=_logical_1d_to_physical_subspace_auto):
    """Maps logical m-dimensional mesh to physical n-dimensional mesh.

    Also see comments to _logical_1d_to_physical_subspace_auto.

    We are mapping a m-dimensional logical mesh to a n-dimensional physical
    mesh.

    output[i] contains the coordinate-tuple in the physical mesh for the i-th
    logical processor (if the logical processors are ordered
    lexicographically).

    sizes_and_strides is a list of m lists of n (size, stride) pairs.

    sizes_and_strides[i] specifies the subspace (strided slice containing the
    origin) of the physical mesh covered by axis i of the logical mesh.  See
    comments to _logical_1d_to_physical_subspace_auto for more detail.

    For example, say we have a physical mesh with shape [4, 4, 2] and a
    logical mesh with shape [4, 8].  We want to divide the physical mesh into
    4 tiles, each with shape [2, 2, 2].  The first logical dimension
    corresponds to which tile, and the second logical dimension corresponds to
    position within a tile.  This would correspond to:
       physical_shape=[4, 4, 2]
       sizes_and_strides=[[(2, 2), (2, 2), (1, 2)], [(2, 1), (2, 1), (2, 1)]]

    physical_shape can be inferred from sizes_and_strides, but is passed in
    for error checking.

    Args:
      sizes_and_strides: a list of m list of n (size, stride) pairs
      physical_shape: a list of integers
      fn_1d: a function like _logical_1d_to_physical_subspace_auto
    Returns:
      a list of coordinate-lists
    Raises:
      ValueError: if the logical and physical meshes have different numbers
        of processors, or if the computed mapping is not a permutation of the
        physical processors.
    """
    pndims = len(physical_shape)
    # Size of each logical axis = product of the sizes of its strided slice.
    logical_shape = [
        mtf.list_product([pair[0] for pair in axis_spec])
        for axis_spec in sizes_and_strides
    ]
    n = mtf.list_product(physical_shape)
    if n != mtf.list_product(logical_shape):
        raise ValueError("logical size and physical size must match "
                         "- got sizes_and_strides=%s physical_shape=%s" %
                         (sizes_and_strides, physical_shape))
    dimension_layouts = [
        fn_1d(axis_spec, physical_shape) for axis_spec in sizes_and_strides
    ]
    # Pass values as logging arguments instead of pre-formatting with `%`:
    # `"%s" % value` raises TypeError when `value` is a multi-element tuple.
    tf.logging.info("physical_shape: %s", physical_shape)
    tf.logging.info("sizes_and_strides: %s", sizes_and_strides)
    for i, layout in enumerate(dimension_layouts):
        tf.logging.info("dimension_layout %s: %s", i, layout)
    ret = []
    for logical_pnum in range(n):
        logical_coordinates = mtf.pnum_to_processor_coordinates(
            logical_shape, logical_pnum)
        # The physical coordinate is the sum, over logical axes, of the
        # offset that each logical axis contributes.
        physical_coordinates = [0] * pndims
        for logical_axis, logical_coord in enumerate(logical_coordinates):
            offset = dimension_layouts[logical_axis][logical_coord]
            for physical_axis in range(pndims):
                physical_coordinates[physical_axis] += offset[physical_axis]
        ret.append(physical_coordinates)
    # verify that we have indeed covered all the processors
    l2p = [mtf.processor_coordinates_to_pnum(physical_shape, c) for c in ret]
    if sorted(l2p) != list(range(n)):
        raise ValueError(
            "logical_to_physical produced something that was not a permutation."
            " sizes_and_strides=%s physical_shape=%s ret=%s" %
            (sizes_and_strides, physical_shape, ret))
    return ret