def __init__(self, element_type, dimensions, layout=None):
  """Creates a new XLA Shape.

  Args:
    element_type: element type from xla_data_pb2.
    dimensions: sequence of dimension sizes (integers), or sequence of Shapes
      in the case of a tuple, i.e. when element_type is TUPLE.
    layout: optional minor_to_major sequence for layout. If not given, the
      default major-to-minor layout is used.

  Raises:
    ValueError: if element_type is TUPLE but dimensions are not Shape objects.
  """
  self.message = xla_data_pb2.Shape()
  self.message.element_type = element_type
  if element_type == xla_data_pb2.TUPLE:
    if not all(isinstance(subshape, Shape) for subshape in dimensions):
      raise ValueError(
          'XLA tuple requires sequence of Shape objects as dimensions')
    self._tuple_shapes = tuple(dimensions)
    for component_shape in self._tuple_shapes:
      component_message = self.message.tuple_shapes.add()
      component_message.CopyFrom(component_shape.message)
  else:
    self.message.dimensions.extend(dimensions)
    if layout is None:
      # Default major-to-minor layout: minor_to_major is the reversed
      # dimension order, e.g. [1, 0] for a rank-2 shape.
      layout = list(reversed(range(len(dimensions))))
    self.message.layout.format = xla_data_pb2.DENSE
    self.message.layout.minor_to_major.extend(layout)
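# --- Example (not part of the original module) ---
# A minimal sketch of constructing shapes with the class above; the import
# path for xla_data_pb2 is an assumption and may differ by TF version.
#
#   from tensorflow.compiler.xla import xla_data_pb2
#
#   matrix = Shape(xla_data_pb2.F32, [128, 64])      # default layout [1, 0]
#   col_major = Shape(xla_data_pb2.F32, [128, 64], layout=[0, 1])
#
#   # A tuple shape: dimensions must be Shape objects, not integers.
#   pair = Shape(xla_data_pb2.TUPLE, [matrix, col_major])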
@classmethod
def split(cls, tensor, split_dimension, num_devices):
  """Returns a Sharding that splits a tensor across a dimension.

  This creates a Tiled attribute, similar to tile(), but easier to use for the
  common case of tiling a tensor N ways in one dimension.

  Args:
    tensor: A tf.Tensor to split.
    split_dimension: The dimension number to split.
    num_devices: The number of cores to split `tensor` over.

  Raises:
    ValueError: The tensor to split was smaller in the split dimension than
      the number of devices to split over.
  """
  shape = tensor.shape.as_list()
  if (shape[split_dimension] is not None and
      shape[split_dimension] < num_devices):
    raise ValueError('Split dimension was smaller than the required number '
                     'of splits: shape=%r, dimension=%r, num_devices=%r' %
                     (shape, split_dimension, num_devices))

  # Work on a copy so `shape` is left untouched, then shrink the split
  # dimension to the per-core tile size, rounding up for uneven splits.
  tile_shape = list(shape)
  tile_shape[split_dimension] = int(
      math.ceil(tile_shape[split_dimension] / num_devices))
  tile_shape_proto = xla_data_pb2.Shape(
      element_type=xla_data_pb2.F32, dimensions=tile_shape)
  tile_assignment_dims = [1] * len(shape)
  tile_assignment_dims[split_dimension] = num_devices
  return Sharding(
      proto=xla_data_pb2.OpSharding(
          type=xla_data_pb2.OpSharding.OTHER,
          tile_shape=tile_shape_proto,
          tile_assignment_dimensions=tile_assignment_dims,
          tile_assignment_devices=range(num_devices)))
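# --- Example (not part of the original module) ---
# A minimal usage sketch, assuming a graph-mode tf.Tensor with a static shape
# and that this classmethod is exposed as Sharding.split:
#
#   import tensorflow as tf
#
#   x = tf.zeros([16, 1024])                     # static shape [16, 1024]
#   sharding = Sharding.split(x, split_dimension=1, num_devices=8)
#   # Each of the 8 cores gets a [16, 128] tile:
#   #   tile_assignment_dimensions == [1, 8]
#   #   tile_assignment_devices == [0, 1, ..., 7]
#   # A non-divisible dimension is rounded up, e.g. 1023 -> tiles of size 128.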