Example #1
    def receive(self, x, mesh_axis, source_pcoord):
        """Collective receive in groups.

        Each group contains the processors that differ only in mesh_axis.

        ```python
        group_size = self.shape[mesh_axis].size
        ```

        Args:
          x: a LaidOutTensor
          mesh_axis: an integer
          source_pcoord: a list of optional integers. Each element is either None
            or an integer in [0, group_size). If source_pcoord[k] is None, then the
            output for the k-th processor in each group is a zero tensor. If
            source_pcoord[k] is not None, then the output for the k-th processor in
            each group is equal to the input for the source_pcoord[k]-th processor
            in that group.

        Returns:
          a LaidOutTensor
        """
        x = x.to_laid_out_tensor()
        t = x.one_slice
        source_target_pairs = []

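        # For each processor pnum, look up which group member it receives from
        # along mesh_axis and record a [source, target] pair (mapped through
        # self.l2p) for collective_permute.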
        for pnum in range(self.size):
            coord = mtf.pnum_to_processor_coordinates(self.shape, pnum)
            k = coord[mesh_axis]
            if source_pcoord[k] is not None:
                coord[mesh_axis] = source_pcoord[k]
                source_pnum = mtf.processor_coordinates_to_pnum(
                    self.shape, coord)
                source_target_pairs.append(
                    [self.l2p(source_pnum),
                     self.l2p(pnum)])

        if not source_target_pairs:
            ret = tf.zeros_like(t, t.dtype)
        elif t.dtype in [tf.float32, tf.bfloat16, tf.int32]:
            ret = tpu_ops.collective_permute(t, source_target_pairs)
        else:
            # t.dtype is not supported by collective_permute: cast to float32,
            # permute, then cast the result back to the original dtype.
            ret = tf.cast(
                tpu_ops.collective_permute(tf.cast(t, tf.float32),
                                           source_target_pairs), t.dtype)

        return self.LaidOutTensor([ret])
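The loop above is the heart of the method: it translates the per-group `source_pcoord` spec into the flat `[source, target]` pairs that `collective_permute` expects. The standalone sketch below replays that logic on a toy mesh; the helper functions, the 2x3 mesh shape, the row-major ordering, and the example `source_pcoord` are all illustrative stand-ins, not the Mesh TensorFlow API.

```python
# Standalone sketch of the pair-construction loop, with made-up helpers
# standing in for mtf.pnum_to_processor_coordinates and
# mtf.processor_coordinates_to_pnum (row-major pnum ordering assumed).

MESH_SHAPE = [2, 3]  # a hypothetical 2 x 3 mesh: 6 processors


def pnum_to_coords(shape, pnum):
    """Flat processor number -> mesh coordinates (row-major)."""
    coords = []
    for dim in reversed(shape):
        coords.append(pnum % dim)
        pnum //= dim
    return coords[::-1]


def coords_to_pnum(shape, coords):
    """Mesh coordinates -> flat processor number (row-major)."""
    pnum = 0
    for dim, c in zip(shape, coords):
        pnum = pnum * dim + c
    return pnum


def receive_pairs(shape, mesh_axis, source_pcoord):
    """Mirror of the loop in receive(): one [source, target] pair per
    processor whose group position k has source_pcoord[k] set."""
    size = 1
    for dim in shape:
        size *= dim
    pairs = []
    for pnum in range(size):
        coord = pnum_to_coords(shape, pnum)
        k = coord[mesh_axis]
        if source_pcoord[k] is not None:
            coord[mesh_axis] = source_pcoord[k]
            pairs.append([coords_to_pnum(shape, coord), pnum])
    return pairs


# Shift data one step along axis 1; position 0 has no source (None),
# so those processors would receive a zero tensor in receive().
print(receive_pairs(MESH_SHAPE, mesh_axis=1, source_pcoord=[None, 0, 1]))
# -> [[0, 1], [1, 2], [3, 4], [4, 5]]
```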
Example #2
  def receive(self, x, mesh_axis, source_pcoord):
    """Collective receive in groups.

    Each group contains the processors that differ only in mesh_axis.

    ```python
    group_size = self.shape[mesh_axis].size
    ```

    Args:
      x: a LaidOutTensor
      mesh_axis: an integer
      source_pcoord: a list of optional integers. Each element is either None
        or an integer in [0, group_size). If source_pcoord[k] is None, then the
        output for the k-th processor in each group is a zero tensor. If
        source_pcoord[k] is not None, then the output for the k-th processor in
        each group is equal to the input for the source_pcoord[k]-th processor
        in that group.

    Returns:
      a LaidOutTensor
    """
    x = x.to_laid_out_tensor()
    t = x.one_slice
    source_target_pairs = []

    for pnum in range(self.size):
      coord = mtf.pnum_to_processor_coordinates(self.shape, pnum)
      k = coord[mesh_axis]
      if source_pcoord[k] is not None:
        # After this substitution, coord names the processor to receive from.
        coord[mesh_axis] = source_pcoord[k]
        source_pnum = mtf.processor_coordinates_to_pnum(self.shape, coord)
        source_target_pairs.append([source_pnum, pnum])

    # Wrap the permuted slice so the return value matches the docstring.
    ret = tpu_ops.collective_permute(t, source_target_pairs)
    return self.LaidOutTensor([ret])
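In both variants, `tpu_ops.collective_permute` interprets each entry of `source_target_pairs` as a `[source, target]` pair of processor numbers: the slice held by the source is delivered to the target in a single permutation step. The first variant is the more defensive of the two: it falls back to a zero tensor when no pairs are generated, routes dtypes that `collective_permute` does not accept through a `tf.float32` round trip, and remaps logical processor numbers to physical ones via `self.l2p` before emitting pairs.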