def _logical_1d_to_physical_subspace_auto(sizes_and_strides, physical_shape):
  """Orders the processors of a strided physical subspace along one logical axis.

  A 1d logical mesh is laid out over a subspace of an n-dimensional physical
  mesh.  The subspace is a strided slice that contains the origin, and
  output[i] is the physical coordinate-list of the i-th logical processor.

  Each entry of sizes_and_strides is a (size, stride) pair describing one
  physical dimension of the slice.  For example,
  sizes_and_strides=[(2, 16), (4, 1)] describes the slice containing
     [(0, 0), (0, 1), (0, 2), (0, 3),
      (16, 0), (16, 1), (16, 2), (16, 3)]

  The order is chosen heuristically, with the goal of optimizing allreduce
  performance: when both of the first two dimensions have size greater than
  one, they are merged and visited in the order given by _ring_2d; otherwise
  the processors are enumerated with mtf.pnum_to_processor_coordinates.

  Args:
    sizes_and_strides: a list of n (size, stride) pairs
    physical_shape: ignored

  Returns:
    a list of coordinate-lists
  """
  del physical_shape  # accepted only for signature compatibility
  num_axes = len(sizes_and_strides)
  dim_sizes = [pair[0] for pair in sizes_and_strides]
  dim_strides = [pair[1] for pair in sizes_and_strides]
  num_processors = mtf.list_product(dim_sizes)

  use_ring = num_axes >= 2 and dim_sizes[0] > 1 and dim_sizes[1] > 1
  if not use_ring:
    unit_coords = [
        mtf.pnum_to_processor_coordinates(dim_sizes, pnum)
        for pnum in range(num_processors)
    ]
  else:
    ring = _ring_2d(dim_sizes[0], dim_sizes[1])
    # Treat the first two dimensions as a single dimension indexed by ring
    # position; the remaining dimensions are enumerated as usual.
    merged_sizes = [dim_sizes[0] * dim_sizes[1]] + dim_sizes[2:]
    unit_coords = []
    for pnum in range(num_processors):
      merged = mtf.pnum_to_processor_coordinates(merged_sizes, pnum)
      unit_coords.append(list(ring[merged[0]]) + merged[1:])

  # Scale the unit-step coordinates by the stride of each dimension.
  return [
      [c * step for c, step in zip(coord, dim_strides)]
      for coord in unit_coords
  ]
def receive(self, x, mesh_axis, source_pcoord):
  """Collective receive within groups along one mesh axis.

  A group consists of the processors whose coordinates differ only in
  mesh_axis.

  ```python
  group_size = self.shape[mesh_axis].size
  ```

  Args:
    x: a LaidOutTensor
    mesh_axis: an integer
    source_pcoord: a list of optional integers. Each element is either None
      or an integer in [0, group_size). If source_pcoord[k] is None, then the
      output for the k-th processor in each group is a zero tensor. If
      source_pcoord[k] is not None, then the output for the k-th processor in
      each group is equal to the input for the source_pcoord[k]-th processor
      in that group.

  Returns:
    a LaidOutTensor
  """
  local_slice = x.to_laid_out_tensor().one_slice

  # Build one [source, target] pair of physical processor numbers for every
  # processor that has a source assigned.
  permute_pairs = []
  for target_pnum in xrange(self.size):
    coords = mtf.pnum_to_processor_coordinates(self.shape, target_pnum)
    source_index = source_pcoord[coords[mesh_axis]]
    if source_index is None:
      continue
    coords[mesh_axis] = source_index
    source_pnum = mtf.processor_coordinates_to_pnum(self.shape, coords)
    permute_pairs.append([self.l2p(source_pnum), self.l2p(target_pnum)])

  allowed_dtypes = [tf.float32, tf.bfloat16, tf.int32]
  if not permute_pairs:
    # Nobody receives anything.
    result = tf.zeros_like(local_slice, local_slice.dtype)
  elif local_slice.dtype in allowed_dtypes:
    result = tpu_ops.collective_permute(local_slice, permute_pairs)
  else:
    # Not one of the allowed types: permute as float32, then cast back.
    as_float = tf.cast(local_slice, tf.float32)
    permuted = tpu_ops.collective_permute(as_float, permute_pairs)
    result = tf.cast(permuted, local_slice.dtype)
  return self.LaidOutTensor([result])
def _default_value():
  """Returns the fallback layout: processors in lexicographic order.

  Reads the free variables `num_cores` and `return_coordinates`, which are
  defined outside this function.

  Returns:
    list(range(num_cores)) when return_coordinates is false; otherwise a list
    with one coordinate-list per processor number.
  """
  pnums = list(range(num_cores))
  if not return_coordinates:
    return pnums
  # mtf.pnum_to_processor_coordinates takes (mesh_shape, pnum); calling it
  # with the processor number alone raises TypeError.
  # NOTE(review): assumes the enclosing function has `physical_shape` in
  # scope and that coordinates are wanted in that mesh - confirm.
  return [
      mtf.pnum_to_processor_coordinates(physical_shape, pnum)
      for pnum in pnums
  ]
def _logical_to_physical_v1(sizes_and_strides,
                            physical_shape,
                            fn_1d=_logical_1d_to_physical_subspace_auto):
  """Maps logical m-dimensional mesh to physical n-dimensional mesh.

  Also see comments to _logical_1d_to_physical_subspace_auto.

  We are mapping an m-dimensional logical mesh to an n-dimensional physical
  mesh.

  output[i] contains the coordinate-tuple in the physical mesh for the i-th
  logical processor (if the logical processors are ordered lexicographically).

  sizes_and_strides is a list of m lists of n (size, stride) pairs.
  sizes_and_strides[i] specifies the subspace (strided slice containing the
  origin) of the physical mesh covered by axis i of the logical mesh.  See
  comments to _logical_1d_to_physical_subspace_auto for more detail.

  For example, say we have a physical mesh with shape [4, 4, 2] and a logical
  mesh with shape [4, 8].  We want to divide the physical mesh into 4 tiles,
  each with shape [2, 2, 2].  The first logical dimension corresponds to which
  tile, and the second logical dimension corresponds to position within a
  tile.  This would correspond to:
     physical_shape=[4, 4, 2]
     sizes_and_strides=[[(2, 2), (2, 2), (1, 2)], [(2, 1), (2, 1), (2, 1)]]

  physical_shape can be inferred from sizes_and_strides, but is passed in for
  error checking.

  Args:
    sizes_and_strides: a list of m lists of n (size, stride) pairs
    physical_shape: a list of integers
    fn_1d: a function like _logical_1d_to_physical_subspace_auto

  Returns:
    a list of coordinate-lists

  Raises:
    ValueError: if the logical and physical meshes differ in size, or if the
      computed mapping is not a permutation of the physical processors.
  """
  pndims = len(physical_shape)
  logical_shape = [
      mtf.list_product([p[0] for p in l]) for l in sizes_and_strides
  ]
  n = mtf.list_product(physical_shape)
  if n != mtf.list_product(logical_shape):
    raise ValueError("logical size and physical size must match "
                     "- got sizes_and_strides=%s physical_shape=%s" %
                     (sizes_and_strides, physical_shape))
  dimension_layouts = [fn_1d(l, physical_shape) for l in sizes_and_strides]
  # Pass values as logging arguments instead of eagerly applying `%`:
  # `"%s" % value` breaks when `value` is a tuple (it is unpacked as the
  # argument list), whereas deferred formatting handles lists and tuples.
  tf.logging.info("physical_shape: %s", physical_shape)
  tf.logging.info("sizes_and_strides: %s", sizes_and_strides)
  for i, l in enumerate(dimension_layouts):
    tf.logging.info("dimension_layout %s: %s", i, l)
  ret = []
  for logical_pnum in range(n):
    logical_coordinates = mtf.pnum_to_processor_coordinates(
        logical_shape, logical_pnum)
    # The physical coordinate is the sum, over logical axes, of the offset
    # that each axis contributes at this processor's logical coordinate.
    physical_coordinates = [0] * pndims
    for logical_axis, logical_coord in enumerate(logical_coordinates):
      axis_offset = dimension_layouts[logical_axis][logical_coord]
      for physical_axis in range(pndims):
        physical_coordinates[physical_axis] += axis_offset[physical_axis]
    ret.append(physical_coordinates)
  # verify that we have indeed covered all the processors
  l2p = [mtf.processor_coordinates_to_pnum(physical_shape, c) for c in ret]
  if sorted(l2p) != list(range(n)):
    raise ValueError(
        "logical_to_physical produced something that was not a permutation."
        "  sizes_and_strides=%s physical_shape=%s ret=%s"
        % (sizes_and_strides, physical_shape, ret))
  return ret