def apply_to_tensor(self,
                    tensor,
                    assign_tuple_sharding=False,
                    use_sharding_op=False):
  """Applies this Sharding attribute to `tensor`.

  Args:
    tensor: A tf.Tensor to split.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: Whether to create a sharding op on `tensor`.

  Returns:
    The tensor with the Sharding attribute applied.
  """
  if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
    proto = self._get_or_create_tuple_proto(tensor.op)
    # We can't mutate an element of old_proto.tuple_shardings, so create
    # a new proto.
    tuple_shardings = list(proto.tuple_shardings)
    tuple_shardings[tensor.value_index] = self._proto
    proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                    tuple_shardings=tuple_shardings)
  else:
    proto = self._proto
  attr_value = proto.SerializeToString()

  if use_sharding_op:
    tensor = tf2xla.sharding(tensor, sharding=attr_value)
  # TODO(jmolloy): This needs to be seriously revisited before declaring this
  # API available for public use.
  # pylint: disable=protected-access
  tensor.op._set_attr('_XlaSharding',
                      attr_value_pb2.AttrValue(s=attr_value))
  return tensor
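# Illustrative sketch (not part of the original source): applying a Sharding
# to a tensor with apply_to_tensor. Assumes the `Sharding` class defined in
# this module and a graph context (e.g. inside tf.function on TPU), since the
# annotation is written onto `tensor.op`.
import tensorflow as tf

@tf.function
def annotated(x):
  # Replicate `x`; use_sharding_op keeps the annotation on a dedicated
  # XlaSharding op instead of mutating the producing op in place.
  return Sharding.replicate().apply_to_tensor(x, use_sharding_op=True)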
# Example 2
def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh,
      where each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that maps each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is the device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.

  Raises:
    ValueError: The number of tensor split dimensions is larger than device mesh
      rank.
  """
  sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  sharding.apply_to_tensor(tensor)
  return tensor
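# Illustrative sketch (not part of the original source): sharding a 2-D tensor
# over a 4x2 device mesh with mesh_split. Assumes 8 logical devices and a
# graph/XLA-compiled context.
import numpy as np
import tensorflow as tf

@tf.function
def mesh_sharded(w):  # w: [1024, 512]
  mesh = np.arange(8).reshape(4, 2)
  # Dim 0 of `w` is split over mesh axis 0 (4-way), dim 1 over mesh axis 1
  # (2-way); a -1 entry would leave that tensor dimension unsharded.
  return mesh_split(w, mesh, [0, 1], use_sharding_op=True)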
# Example 3
def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  """Returns a tensor that has replicated sharding."""
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  Sharding.replicate().apply_to_tensor(
      tensor,
      assign_tuple_sharding=assign_tuple_sharding)
  return tensor
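# Illustrative sketch (not part of the original source): marking a tensor as
# replicated across all devices.
import tensorflow as tf

@tf.function
def replicated(x):
  return replicate(x, use_sharding_op=True)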
# Example 4
def apply_to_tensor(self,
                    tensor,
                    assign_tuple_sharding=False,
                    use_sharding_op=False,
                    unspecified_dims=None):
  """Applies this Sharding attribute to `tensor`.

  Args:
    tensor: A tf.Tensor to split.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: Whether to create a sharding op on `tensor`.
    unspecified_dims: An optional list of dimensions whose sharding is left
      unspecified.

  Returns:
    The tensor with the Sharding attribute applied.
  """
  if unspecified_dims:
    assert use_sharding_op and not assign_tuple_sharding
  proto = self._proto
  if use_sharding_op:
    if assign_tuple_sharding:
      proto = self._create_tuple_proto(num_outputs=1)
      tensor = tf2xla.sharding(tensor,
                               sharding=proto.SerializeToString())
    else:
      tensor = tf2xla.sharding(tensor,
                               sharding=proto.SerializeToString(),
                               unspecified_dims=unspecified_dims or [])
  elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
    proto = self._get_or_create_tuple_proto(tensor.op)
    # We can't mutate an element of old_proto.tuple_shardings, so create
    # a new proto.
    tuple_shardings = list(proto.tuple_shardings)
    tuple_shardings[tensor.value_index] = self._proto
    proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                    tuple_shardings=tuple_shardings)

  # TODO(jmolloy): This needs to be seriously revisited before declaring this
  # API available for public use.
  # pylint: disable=protected-access
  tensor.op._set_attr(
      '_XlaSharding',
      attr_value_pb2.AttrValue(s=proto.SerializeToString()))
  return tensor
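# Illustrative sketch (not part of the original source): the unspecified_dims
# path is only valid with use_sharding_op=True and without tuple sharding
# (see the assert above). Here dim 1 is left open for the partitioner to
# refine later; availability of this argument depends on the TF version.
import numpy as np
import tensorflow as tf

@tf.function
def partially_annotated(x):  # x: [8, 16]
  sharding = Sharding.tile(np.arange(2).reshape(2, 1))
  return sharding.apply_to_tensor(x, use_sharding_op=True,
                                  unspecified_dims=[1])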
# Example 5
def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
    """Returns a tensor that has AssignDevice sharding attribute."""
    if use_sharding_op:
        tensor = tf2xla.sharding(tensor)

    Sharding.assign_device(device).apply_to_tensor(
        tensor, assign_tuple_sharding=assign_tuple_sharding)
    return tensor
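# Illustrative sketch (not part of the original source): pinning a tensor to a
# single device. Assumes `device` is a logical core id (an int).
import tensorflow as tf

@tf.function
def on_core_zero(x):
  return assign_device(x, 0, use_sharding_op=True)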
# Example 6
def partial_tile(tensor, tile_assignment, use_sharding_op=False):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than `tensor`, and its last dimension indexes the devices
      that hold each partially replicated tile.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  if use_sharding_op:
    tensor = tf2xla.sharding(tensor)
  Sharding.partial_tile(tile_assignment).apply_to_tensor(tensor)
  return tensor
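# Illustrative sketch (not part of the original source): partially replicated
# tiling over 4 devices. The tile_assignment has rank(tensor) + 1 dimensions;
# its last dimension (size 2) lists the devices replicating each tile.
import numpy as np
import tensorflow as tf

@tf.function
def partially_tiled(z):  # z: [32, 32]
  assignment = np.arange(4).reshape(2, 1, 2)  # 2-way split on dim 0, 2 replicas
  return partial_tile(z, assignment, use_sharding_op=True)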
# Example 7
def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False):
    """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
    if use_sharding_op:
        tensor = tf2xla.sharding(tensor)
    Sharding.tile(tile_assignment).apply_to_tensor(
        tensor, assign_tuple_sharding=assign_tuple_sharding)
    return tensor
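# Illustrative sketch (not part of the original source): a 2x2 tiled layout
# over 4 devices, splitting dim 0 and dim 1 in half each.
import numpy as np
import tensorflow as tf

@tf.function
def tiled(y):  # y: [16, 64]
  return tile(y, np.arange(4).reshape(2, 2), use_sharding_op=True)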
# Example 8
def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False):
    """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
    if use_sharding_op:
        tensor = tf2xla.sharding(tensor)
    Sharding.split(tensor, split_dimension, num_devices).apply_to_tensor(
        tensor, assign_tuple_sharding=assign_tuple_sharding)
    return tensor
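# Illustrative sketch (not part of the original source): splitting a single
# dimension evenly across devices.
import tensorflow as tf

@tf.function
def row_split(x):  # x: [8, 128]
  # Shard dim 0 of `x` across 4 devices.
  return split(x, split_dimension=0, num_devices=4, use_sharding_op=True)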
# Example 9
def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies the a tensor's sharding to another.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: the tensor the annotate with the copy.
    use_sharding_op: whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  sharding = get_op_sharding(from_tensor.op)
  if sharding is None:
    return to_tensor

  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor)
  attr_value = attr_value_pb2.AttrValue(s=sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr_value)
  return to_tensor
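# Illustrative sketch (not part of the original source): propagating an
# existing annotation to another tensor. Assumes the `split` helper defined
# earlier in this module.
import tensorflow as tf

@tf.function
def mirrored_sharding(a, b):  # a, b: same-shape tensors
  a = split(a, split_dimension=0, num_devices=2, use_sharding_op=True)
  # `b` now carries the same _XlaSharding annotation as `a`.
  b = copy_sharding(a, b, use_sharding_op=True)
  return a, b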