def _create_tuple_proto(self, op):
  shardings = [
      xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
      for _ in op.outputs
  ]
  return xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)
def testCreateSlotWithCustomSplitXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # We insert our own custom split XLA sharding that overrides the SPMD
  # sharding copied over by the slot_creator.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_zeros_slot(
          v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      slot = xla_sharding.split(
          slot, split_dimension=0, num_devices=4, use_sharding_op=False)

    self.assertNotEqual(
        xla_sharding.get_tensor_sharding(v),
        xla_sharding.get_tensor_sharding(slot))

    slot_sharding = xla_sharding.get_tensor_sharding(slot)
    slot_proto = xla_data_pb2.OpSharding()
    slot_proto.ParseFromString(slot_sharding)
    self.assertEqual(
        slot_proto,
        xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=[4],
            tile_assignment_devices=range(4)))
def testCreateSlotWithCustomReplicatedXlaSharding(self):
  # slot_creator is used only in optimizer V1.
  # We insert our own custom replicated XLA sharding that overrides the SPMD
  # sharding copied over by the slot_creator.
  with ops.Graph().as_default(), self.cached_session():
    v = variables.Variable([1.0, 2.5], name="var")
    v = xla_sharding.mesh_split(
        v, np.array([0, 1]), [0], use_sharding_op=False)
    with ops.control_dependencies(None):
      slot = slot_creator.create_zeros_slot(
          v, name="slot", dtype=dtypes.float64, copy_xla_sharding=True)
      slot = xla_sharding.replicate(slot, use_sharding_op=False)

    self.assertNotEqual(
        xla_sharding.get_tensor_sharding(v),
        xla_sharding.get_tensor_sharding(slot))

    slot_sharding = xla_sharding.get_tensor_sharding(slot)
    slot_proto = xla_data_pb2.OpSharding()
    slot_proto.ParseFromString(slot_sharding)
    self.assertEqual(
        slot_proto,
        xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
def apply_to_tensor(self,
                    tensor,
                    assign_tuple_sharding=False,
                    use_sharding_op=False):
  """Applies this Sharding attribute to `tensor`.

  Args:
    tensor: A tf.Tensor to split.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: Whether to create a sharding op on `tensor`.

  Returns:
    The tensor with the Sharding attribute.
  """
  if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
    proto = self._get_or_create_tuple_proto(tensor.op)
    # We can't mutate an element of old_proto.tuple_shardings, so create
    # a new proto.
    tuple_shardings = list(proto.tuple_shardings)
    tuple_shardings[tensor.value_index] = self._proto
    proto = xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)
  else:
    proto = self._proto

  attr_value = proto.SerializeToString()

  if use_sharding_op:
    tensor = tf2xla.sharding(tensor, sharding=attr_value)
  # TODO(jmolloy): This needs to be seriously revisited before declaring this
  # API available for public use.
  # pylint: disable=protected-access
  tensor.op._set_attr('_XlaSharding', attr_value_pb2.AttrValue(s=attr_value))
  return tensor
def ipu_shard(index): """Control sharding for a set of operations. Provides a scope which targets operations onto a particular shard (IPU) of a multi-IPU sharded device. Args: index: The index of the IPU on which to place the enclosed operations. Returns: A context """ if hasattr(index, '__iter__'): ipus = index else: ipus = [index] proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MAXIMAL, tile_assignment_devices=ipus) attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) attrs = {"_XlaSharding": attr_value} # pylint: disable=protected-access with ops.get_default_graph()._attr_scope(attrs): yield
def tile(cls, tile_assignment):
  """Returns a Tiled sharding attribute.

  This causes an op to be partially computed on multiple cores in the
  XLA device.

  Args:
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.

  Raises:
    TypeError: tile_assignment was not of np.array type.

  TODO(jmolloy): This concept is nefarious and is not something we really
  want to expose to users (especially as the contract for tile_assignment
  is very strict).
  """
  if not isinstance(tile_assignment, _np.ndarray):
    raise TypeError('Tile assignment must be of type np.ndarray')
  dims = list(tile_assignment.shape)
  flattened_devices = tile_assignment.reshape(-1, order='C')
  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.OTHER,
          tile_assignment_dimensions=dims,
          tile_assignment_devices=list(flattened_devices)))
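# A minimal sketch of Sharding.tile above, assuming this module is importable
# as `xla_sharding` (the exact import path differs across TF versions): tile a
# [4, 8] tensor over a 2x2 grid of devices 0..3.
import numpy as np
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  x = tf.ones([4, 8])
  tiled = xla_sharding.Sharding.tile(np.arange(4).reshape(2, 2))
  x = tiled.apply_to_tensor(x, use_sharding_op=True)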
def split(cls, tensor, split_dimension, num_devices, input_shape=None):
  """Returns a Sharding that splits a tensor across a dimension.

  This creates a Tiled attribute, similar to tile(), but easier to use for
  the common case of tiling a tensor N ways in one dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension number to split.
    num_devices: The number of cores to split `tensor` over.
    input_shape: The shape of the original tensor.

  Raises:
    ValueError: The tensor to split was smaller in the split dimension than
      the number of devices to split over.
  """
  if input_shape:
    shape = input_shape
  else:
    shape = tensor.shape.as_list()
  if (shape[split_dimension] is not None and
      shape[split_dimension] < num_devices):
    raise ValueError('Split dimension was smaller than the required number '
                     'of splits: shape=%r, dimension=%r, num_devices=%r' %
                     (shape, split_dimension, num_devices))

  tile_assignment_dims = [1] * len(shape)
  tile_assignment_dims[split_dimension] = num_devices

  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.OTHER,
          tile_assignment_dimensions=tile_assignment_dims,
          tile_assignment_devices=range(num_devices)))
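# A minimal sketch of Sharding.split above, again assuming this module is
# importable as `xla_sharding`: split an [8, 16] tensor 4 ways along dimension
# 0, which yields tile_assignment_dimensions=[4, 1] over devices 0..3.
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  x = tf.ones([8, 16])
  x = xla_sharding.Sharding.split(
      x, split_dimension=0, num_devices=4).apply_to_tensor(
          x, use_sharding_op=True)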
def manual(cls):
  """Returns a manual sharding attribute.

  This means the op is manually partitioned by the user and XLA will not
  change the shapes.
  """
  return Sharding(
      proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))
def replicate(cls):
  """Returns a replicated sharding attribute.

  This causes an op to be computed in its entirety independently on all
  cores in the XLA device.
  """
  return Sharding(
      proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
def _get_or_create_tuple_proto(self, op):
  try:
    attr = op.get_attr('_XlaSharding')
    proto = xla_data_pb2.OpSharding()
    proto.ParseFromString(attr)
    return proto
  except ValueError:
    return self._create_tuple_proto(op)
def assign_device(cls, core):
  """Returns an AssignDevice sharding attribute.

  This causes an op to be computed in its entirety only on one core in
  the XLA device.

  Args:
    core: The core to assign this Op to.
  """
  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.MAXIMAL,
          tile_assignment_dimensions=[1],
          tile_assignment_devices=[core]))
def GetVarSharding(var: tf.Variable) -> TensorShardingSpec:
  """Returns the sharding directly attached to a variable."""
  sharding = xla_sharding.get_op_sharding(var.op)
  if not sharding:
    return TensorShardingSpec.ReplicatedSpec()

  proto = xla_data_pb2.OpSharding()
  proto.ParseFromString(sharding)
  spec_without_padding = TensorShardingSpec.FromXlaOpSharding(proto)
  # Consider uneven padding.
  return TensorShardingSpec.FromFullShape(
      [int(d) for d in var.shape], spec_without_padding.split_dims_mapping,
      spec_without_padding.device_mesh)
def set_ipu_shard(op, index): """Set shard index for op. Args: op(tf.Operation): Operator index(int): IPU index """ proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MAXIMAL, tile_assignment_devices=[index]) attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) op._set_attr(sharding._XLA_SHARDING, attr_value)
def ipu_shard(index):
  if hasattr(index, '__iter__'):
    ipus = index
  else:
    ipus = [index]

  proto = xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.MAXIMAL, tile_assignment_devices=ipus)

  attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
  attrs = {"_XlaSharding": attr_value}

  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
def subgroup_tile(cls, tile_assignment, subgroup_modes):
  """Returns a subgroup manual sharding attribute.

  This is similar to tile(), but tile_assignment has one or more extra
  dimensions relative to the tensor, and subgroup_modes defines the sharding
  types for those trailing dimensions of tile_assignment.

  Args:
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.
    subgroup_modes: Sharding types for the tile_assignment dimensions beyond
      the tensor shape rank.

  Raises:
    TypeError: tile_assignment was not of np.array type or subgroup_modes
      contains an unsupported sharding type.
  """
  if not isinstance(tile_assignment, _np.ndarray):
    raise TypeError('SubgroupTile assignment must be of type np.ndarray')
  if not isinstance(subgroup_modes, list):
    raise TypeError('subgroup_modes in subgroup manual must be of type list')
  if len(tile_assignment.shape) < len(subgroup_modes):
    raise TypeError('SubgroupTile assignment must have rank at least as '
                    'large as the length of subgroup_modes')
  for sharding_type in subgroup_modes:
    if sharding_type not in [
        xla_data_pb2.OpSharding.REPLICATED, xla_data_pb2.OpSharding.MANUAL
    ]:
      raise TypeError(
          'Each sharding_type in subgroup_modes in subgroup manual must '
          'be of type xla_data_pb2.OpSharding.REPLICATED'
          ' or xla_data_pb2.OpSharding.MANUAL')
  dims = list(tile_assignment.shape)
  flattened_devices = tile_assignment.reshape(-1, order='C')
  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.OTHER,
          tile_assignment_dimensions=dims,
          tile_assignment_devices=list(flattened_devices),
          last_tile_dims=list(subgroup_modes)))
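# A minimal sketch of Sharding.subgroup_tile above, assuming this module is
# importable as `xla_sharding` and xla_data_pb2 is available: 8 devices
# arranged as [2, 2, 2], where the first two dimensions tile a rank-2 tensor
# and the trailing dimension is a REPLICATED subgroup.
import numpy as np
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  x = tf.ones([4, 4])
  sub = xla_sharding.Sharding.subgroup_tile(
      np.arange(8).reshape(2, 2, 2), [xla_data_pb2.OpSharding.REPLICATED])
  x = sub.apply_to_tensor(x, use_sharding_op=True)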
def apply_to_tensor(self, tensor): """Applies this Sharding attribute to `tensor`.""" if len(tensor.op.outputs) > 1: proto = self._get_or_create_tuple_proto(tensor.op) # We can't mutate an element of old_proto.tuple_shardings, so create # a new proto. tuple_shardings = list(proto.tuple_shardings) tuple_shardings[tensor.value_index] = self._proto proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings) else: proto = self._proto attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString()) # TODO(jmolloy): This need to be seriously revisited before declaring this # API available for public use. # pylint: disable=protected-access tensor.op._set_attr('_XlaSharding', attr_value)
def apply_to_tensor(self,
                    tensor,
                    assign_tuple_sharding=False,
                    use_sharding_op=False,
                    unspecified_dims=None):
  """Applies this Sharding attribute to `tensor`.

  Args:
    tensor: A tf.Tensor to split.
    assign_tuple_sharding: If the sharding type should be a tuple.
    use_sharding_op: Whether to create a sharding op on `tensor`.
    unspecified_dims: An optional list of dimensions unspecified.

  Returns:
    The tensor with the Sharding attribute.
  """
  if unspecified_dims:
    assert use_sharding_op and not assign_tuple_sharding
  proto = self._proto
  if use_sharding_op:
    if assign_tuple_sharding:
      proto = self._create_tuple_proto(num_outputs=1)
      tensor = tf2xla.sharding(tensor, sharding=proto.SerializeToString())
    else:
      tensor = tf2xla.sharding(
          tensor,
          sharding=proto.SerializeToString(),
          unspecified_dims=unspecified_dims or [])
  elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
    proto = self._get_or_create_tuple_proto(tensor.op)
    # We can't mutate an element of old_proto.tuple_shardings, so create
    # a new proto.
    tuple_shardings = list(proto.tuple_shardings)
    tuple_shardings[tensor.value_index] = self._proto
    proto = xla_data_pb2.OpSharding(
        type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=tuple_shardings)

  # TODO(jmolloy): This needs to be seriously revisited before declaring this
  # API available for public use.
  # pylint: disable=protected-access
  tensor.op._set_attr(
      '_XlaSharding', attr_value_pb2.AttrValue(s=proto.SerializeToString()))
  return tensor
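# A minimal sketch of apply_to_tensor above, assuming this module is
# importable as `xla_sharding`: attach a REPLICATED sharding via an explicit
# sharding op, and a MAXIMAL (device 0) sharding directly as an op attribute.
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  x = tf.ones([2, 2])
  x = xla_sharding.Sharding.replicate().apply_to_tensor(
      x, use_sharding_op=True)
  y = tf.ones([2, 2])
  y = xla_sharding.Sharding.assign_device(0).apply_to_tensor(y)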
def get_sharding_tile_shape(sharding):
  """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: a serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
      into which it has been split. Returns None if the input indicates no
      tile assignments.
  """
  if sharding is None:
    return None
  sharding_message = xla_data_pb2.OpSharding()
  sharding_message.ParseFromString(sharding)
  if sharding_message.tile_assignment_dimensions:
    return sharding_message.tile_assignment_dimensions
  else:
    return None
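# A minimal sketch of get_sharding_tile_shape above, assuming this module is
# importable as `xla_sharding`: a tensor split 4 ways along dimension 0
# decodes to a [4, 1] tile assignment, while a replicated sharding has no
# tile dimensions and yields None.
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  x = tf.ones([8, 16])
  x = xla_sharding.split(
      x, split_dimension=0, num_devices=4, use_sharding_op=True)
  serialized = xla_sharding.get_tensor_sharding(x)
  print(get_sharding_tile_shape(serialized))  # [4, 1]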
def split(cls, tensor, split_dimension, num_devices):
  """Returns a Sharding that splits a tensor across a dimension.

  This creates a Tiled attribute, similar to tile(), but easier to use for
  the common case of tiling a tensor N ways in one dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension number to split.
    num_devices: The number of cores to split `tensor` over.

  Raises:
    ValueError: The tensor to split was smaller in the split dimension than
      the number of devices to split over.
  """
  tensor.shape.assert_is_fully_defined()
  shape = tensor.shape.as_list()
  if shape[split_dimension] < num_devices:
    raise ValueError('Split dimension was smaller than the required number '
                     'of splits: shape=%r, dimension=%r, num_devices=%r' %
                     (shape, split_dimension, num_devices))

  tile_shape = shape
  tile_shape[split_dimension] = int(
      math.ceil(tile_shape[split_dimension] / num_devices))
  tile_shape_proto = xla_data_pb2.Shape(
      element_type=xla_data_pb2.F32, dimensions=tile_shape)

  tile_assignment_dims = [1] * len(shape)
  tile_assignment_dims[split_dimension] = num_devices

  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.OTHER,
          tile_shape=tile_shape_proto,
          tile_assignment_dimensions=tile_assignment_dims,
          tile_assignment_devices=range(num_devices)))
def partial_tile(cls, tile_assignment):
  """Returns a partially tiled sharding attribute.

  This is similar to tile(), but tile_assignment has one more dimension than
  the tensor, and tiles in the last dimension of tile_assignment are
  replicated.

  Args:
    tile_assignment: An np.ndarray describing the topology of the tiling and
      which device will compute which part of the topology.

  Raises:
    TypeError: tile_assignment was not of np.array type.
  """
  if not isinstance(tile_assignment, _np.ndarray):
    raise TypeError('PartialTile assignment must be of type np.ndarray')
  dims = list(tile_assignment.shape)
  flattened_devices = tile_assignment.reshape(-1, order='C')
  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.OTHER,
          tile_assignment_dimensions=dims,
          tile_assignment_devices=list(flattened_devices),
          replicate_on_last_tile_dim=True))
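# A minimal sketch of Sharding.partial_tile above, assuming this module is
# importable as `xla_sharding`: 8 devices arranged as [4, 2], so a rank-1
# tensor is tiled 4 ways and each tile is replicated across a group of 2
# devices.
import numpy as np
import tensorflow.compat.v1 as tf

with tf.Graph().as_default():
  v = tf.ones([16])
  part = xla_sharding.Sharding.partial_tile(np.arange(8).reshape(4, 2))
  v = part.apply_to_tensor(v, use_sharding_op=True)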
def _create_tuple_proto(self, num_outputs):
  shardings = [
      xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
  ] * num_outputs
  return xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)
def set_ipu_shard(op, index):
  proto = xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.MAXIMAL, tile_assignment_devices=[index])

  attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
  op._set_attr(sharding._XLA_SHARDING, attr_value)