def apply_to_tensor(self,
                    tensor,
                    assign_tuple_sharding=False,
                    use_sharding_op=False):
  """Applies this Sharding attribute to `tensor`.

  Serializes this sharding's proto and attaches it to `tensor`'s producing
  op via the `_XlaSharding` attribute (optionally routing through a
  dedicated sharding op first).

  Args:
    tensor: A tf.Tensor to split.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: whether to create a sharding op on `tensor`.

  Returns:
    The tensor with Sharding attribute.
  """
  if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
    # Multi-output ops carry one sharding per output, wrapped in a TUPLE
    # proto; fetch (or build) that tuple and replace only this output's slot.
    proto = self._get_or_create_tuple_proto(tensor.op)
    # We can't mutate an element of old_proto.tuple_shardings, so create
    # a new proto.
    tuple_shardings = list(proto.tuple_shardings)
    tuple_shardings[tensor.value_index] = self._proto
    proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                    tuple_shardings=tuple_shardings)
  else:
    proto = self._proto
  attr_value = proto.SerializeToString()
  if use_sharding_op:
    # Insert an identity-like sharding op so the annotation attaches to a
    # fresh op rather than mutating the original producer's attr alone.
    tensor = tf2xla.sharding(tensor, sharding=attr_value)
  # TODO(jmolloy): This need to be seriously revisited before declaring this
  # API available for public use.
  # pylint: disable=protected-access
  tensor.op._set_attr('_XlaSharding', attr_value_pb2.AttrValue(s=attr_value))
  return tensor
def mesh_split(tensor,
               device_mesh,
               tensor_split_dims_mapping,
               use_sharding_op=False):
  """Returns a tensor that is split along multiple dimensions in a device mesh.

  Args:
    tensor: A tf.Tensor to split.
    device_mesh: An np.ndarray describing the topology of the device mesh and
      each element is the ID of the device in the topology.
    tensor_split_dims_mapping: A list of integers that map each tensor axis to
      the device mesh axis along which it is sharded. Its length is the tensor
      rank, and tensor_split_dims_mapping[i] is device mesh axis for tensor
      dimension i. Use -1 for tensor dimensions that are not sharded.
    use_sharding_op: If true, adds a sharding op to set the sharding.

  Raises:
    ValueError: The number of tensor split dimensions is larger than device
      mesh rank.
  """
  # Build the mesh-based sharding spec first; this is also where the
  # rank-mismatch ValueError surfaces.
  mesh_sharding = mesh_split_sharding(device_mesh, tensor_split_dims_mapping)
  annotated = tf2xla.sharding(tensor) if use_sharding_op else tensor
  mesh_sharding.apply_to_tensor(annotated)
  return annotated
def replicate(tensor, assign_tuple_sharding=False, use_sharding_op=False):
  """Annotates `tensor` with a replicated (unsharded) sharding attribute.

  Args:
    tensor: A tf.Tensor to annotate.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.

  Returns:
    The annotated tensor.
  """
  annotated = tf2xla.sharding(tensor) if use_sharding_op else tensor
  replicated = Sharding.replicate()
  replicated.apply_to_tensor(
      annotated, assign_tuple_sharding=assign_tuple_sharding)
  return annotated
def apply_to_tensor(self,
                    tensor,
                    assign_tuple_sharding=False,
                    use_sharding_op=False,
                    unspecified_dims=None):
  """Applies this Sharding attribute to `tensor`.

  Args:
    tensor: A tf.Tensor to split.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: Whether to create a sharding op on `tensor`.
    unspecified_dims: An optional list of dimensions unspecified.

  Returns:
    The tensor with Sharding attribute.
  """
  if unspecified_dims:
    # unspecified_dims is only meaningful on the sharding-op path and is
    # incompatible with tuple shardings.
    # NOTE(review): `assert` is stripped under `python -O`; this is a
    # programmer-error guard, not input validation — confirm intent.
    assert use_sharding_op and not assign_tuple_sharding
  proto = self._proto
  if use_sharding_op:
    if assign_tuple_sharding:
      # Wrap this sharding in a single-element tuple for the sharding op.
      proto = self._create_tuple_proto(num_outputs=1)
      tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
    else:
      tensor = tf2xla.sharding(
          tensor,
          sharding=proto.SerializeToString(),
          unspecified_dims=unspecified_dims or [])
  elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
    # Multi-output ops carry one sharding per output inside a TUPLE proto;
    # replace only the slot for this output.
    proto = self._get_or_create_tuple_proto(tensor.op)
    # We can't mutate an element of old_proto.tuple_shardings, so create
    # a new proto.
    tuple_shardings = list(proto.tuple_shardings)
    tuple_shardings[tensor.value_index] = self._proto
    proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                    tuple_shardings=tuple_shardings)
  # TODO(jmolloy): This need to be seriously revisited before declaring this
  # API available for public use.
  # pylint: disable=protected-access
  tensor.op._set_attr(
      '_XlaSharding', attr_value_pb2.AttrValue(s=proto.SerializeToString()))
  return tensor
def assign_device(tensor,
                  device,
                  assign_tuple_sharding=False,
                  use_sharding_op=False):
  """Returns a tensor that has AssignDevice sharding attribute."""
  annotated = tf2xla.sharding(tensor) if use_sharding_op else tensor
  device_sharding = Sharding.assign_device(device)
  device_sharding.apply_to_tensor(
      annotated, assign_tuple_sharding=assign_tuple_sharding)
  return annotated
def partial_tile(tensor, tile_assignment, use_sharding_op=False):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology. It must have one
      more dimension than tensor, and the last dimension represents partially
      replicated tiles.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  annotated = tf2xla.sharding(tensor) if use_sharding_op else tensor
  partial_sharding = Sharding.partial_tile(tile_assignment)
  partial_sharding.apply_to_tensor(annotated)
  return annotated
def tile(tensor,
         tile_assignment,
         assign_tuple_sharding=False,
         use_sharding_op=False):
  """Returns a tensor that has tiled sharding.

  Args:
    tensor: A tf.Tensor to shard.
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  annotated = tf2xla.sharding(tensor) if use_sharding_op else tensor
  tile_sharding = Sharding.tile(tile_assignment)
  tile_sharding.apply_to_tensor(
      annotated, assign_tuple_sharding=assign_tuple_sharding)
  return annotated
def split(tensor,
          split_dimension,
          num_devices,
          assign_tuple_sharding=False,
          use_sharding_op=False):
  """Returns a tensor that is split along the given dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension to split.
    num_devices: The number of devices to partition the dimension.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: If true, adds a sharding op to set the sharding.
  """
  annotated = tf2xla.sharding(tensor) if use_sharding_op else tensor
  # Note: Sharding.split validates against the ORIGINAL tensor's shape.
  split_sharding = Sharding.split(tensor, split_dimension, num_devices)
  split_sharding.apply_to_tensor(
      annotated, assign_tuple_sharding=assign_tuple_sharding)
  return annotated
def copy_sharding(from_tensor, to_tensor, use_sharding_op=False):
  """Copies one tensor's sharding annotation to another tensor.

  Args:
    from_tensor: Source tensor. Must be the sole output of an op.
    to_tensor: The tensor to annotate with the copied sharding.
    use_sharding_op: Whether to create a sharding op on `to_tensor`.

  Returns:
    A tensor with sharding annotation copied from `from_tensor`.
  """
  source_sharding = get_op_sharding(from_tensor.op)
  # Nothing to copy — return the target untouched.
  if source_sharding is None:
    return to_tensor
  if use_sharding_op:
    to_tensor = tf2xla.sharding(to_tensor)
  attr = attr_value_pb2.AttrValue(s=source_sharding)
  # pylint: disable=protected-access
  to_tensor.op._set_attr('_XlaSharding', attr)
  return to_tensor