def receive(self, x, mesh_axis, source_pcoord):
  """Collective receive within groups of processors.

  Processors are grouped so that the members of one group differ only in
  their coordinate along `mesh_axis`:

  ```python
  group_size = self.shape[mesh_axis].size
  ```

  Args:
    x: a LaidOutTensor
    mesh_axis: an integer - the mesh dimension along which groups are formed
    source_pcoord: a list of optional integers, one per group position.
      Entry k is either None or an integer in [0, group_size).  When entry
      k is None, the k-th processor of every group receives a zero tensor.
      Otherwise the k-th processor of every group receives the input slice
      held by the source_pcoord[k]-th processor of its group.

  Returns:
    a LaidOutTensor
  """
  x = x.to_laid_out_tensor()
  local_slice = x.one_slice
  # Build (source, target) pairs in physical processor numbering.
  pairs = []
  for pnum in xrange(self.size):
    coords = mtf.pnum_to_processor_coordinates(self.shape, pnum)
    group_index = coords[mesh_axis]
    if source_pcoord[group_index] is None:
      # This position receives zeros; emit no pair for it.
      continue
    coords[mesh_axis] = source_pcoord[group_index]
    src_pnum = mtf.processor_coordinates_to_pnum(self.shape, coords)
    pairs.append([self.l2p(src_pnum), self.l2p(pnum)])
  if not pairs:
    # Every entry of source_pcoord was None: everyone gets zeros.
    permuted = tf.zeros_like(local_slice, local_slice.dtype)
  else:
    dtype = local_slice.dtype
    if dtype in [tf.float32, tf.bfloat16, tf.int32]:
      permuted = tpu_ops.collective_permute(local_slice, pairs)
    else:
      # collective_permute does not support this dtype directly;
      # round-trip through float32.
      as_f32 = tf.cast(local_slice, tf.float32)
      permuted = tf.cast(tpu_ops.collective_permute(as_f32, pairs), dtype)
  return self.LaidOutTensor([permuted])
def receive(self, x, mesh_axis, source_pcoord):
  """Collective receive in groups.

  Each group contains the processors that differ only in mesh_axis.

  ```python
  group_size = self.shape[mesh_axis].size
  ```

  Args:
    x: a LaidOutTensor
    mesh_axis: an integer
    source_pcoord: a list of optional integers. Each element is either None
      or an integer in [0, group_size). If source_pcoord[k] is None, then the
      output for the k-th processor in each group is a zero tensor. If
      source_pcoord[k] is not None, then the output for the k-th processor in
      each group is equal to the input for the source_pcoord[k]-th processor
      in that group.

  Returns:
    a LaidOutTensor
  """
  x = x.to_laid_out_tensor()
  t = x.one_slice
  source_target_pairs = []
  for pnum in xrange(self.size):
    # Use the mtf module helpers (this class does not define these methods),
    # passing the mesh shape explicitly.
    coord = mtf.pnum_to_processor_coordinates(self.shape, pnum)
    k = coord[mesh_axis]
    if source_pcoord[k] is not None:
      # pnum *receives* from the processor at source_pcoord[k] along
      # mesh_axis, so the rewritten coordinate yields the SOURCE of the
      # pair, not the target (the original had the direction inverted).
      coord[mesh_axis] = source_pcoord[k]
      source_pnum = mtf.processor_coordinates_to_pnum(self.shape, coord)
      # Map logical processor numbers to physical ones for the TPU op.
      source_target_pairs.append(
          [self.l2p(source_pnum), self.l2p(pnum)])
  if not source_target_pairs:
    # All entries of source_pcoord were None: every processor outputs zeros.
    ret = tf.zeros_like(t, t.dtype)
  elif t.dtype in [tf.float32, tf.bfloat16, tf.int32]:
    ret = tpu_ops.collective_permute(t, source_target_pairs)
  else:
    # collective_permute does not support this dtype; cast and cast back.
    ret = tf.cast(
        tpu_ops.collective_permute(
            tf.cast(t, tf.float32), source_target_pairs), t.dtype)
  # Wrap the slice so the documented "Returns: a LaidOutTensor" contract
  # holds (the original returned a raw tf.Tensor).
  return self.LaidOutTensor([ret])