def __init__(self, element_type, dimensions, layout=None):
  """Creates a new XLA Shape.

  Args:
    element_type: element type from xla_data_pb2.
    dimensions: sequence of dimension sizes (integers), or sequence
      of Shapes in the case of a tuple, i.e. when element_type is
      TUPLE.
    layout: optional minor_to_major sequence for layout. If not given, the
      default major-to-minor layout is used.

  Raises:
    ValueError: if element_type is TUPLE but dimensions are not Shape objects.
  """
  self.message = xla_data_pb2.Shape()
  self.message.element_type = element_type
  if element_type == xla_data_pb2.TUPLE:
    if not all(isinstance(subshape, Shape) for subshape in dimensions):
      raise ValueError(
          'XLA tuple requires sequence of Shape objects as dimensions')
    self._tuple_shapes = tuple(dimensions)
    for component_shape in self._tuple_shapes:
      component_message = self.message.tuple_shapes.add()
      component_message.CopyFrom(component_shape.message)
  else:
    self.message.dimensions.extend(dimensions)
    if layout is None:
      # Default layout is major-to-minor, i.e. minor_to_major is
      # [rank-1, ..., 1, 0].
      layout = list(reversed(range(len(dimensions))))
    self.message.layout.format = xla_data_pb2.DENSE
    self.message.layout.minor_to_major.extend(layout)
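
A minimal usage sketch (not part of the original source): constructing array and tuple shapes with the constructor above. It assumes xla_data_pb2 is imported as in the surrounding module and that F32 and TUPLE are its element-type enums; the variable names are illustrative.

# Rank-2 F32 shape; with no explicit layout, minor_to_major defaults to
# [1, 0] (major-to-minor, i.e. the last dimension is minor-most).
matrix = Shape(xla_data_pb2.F32, [128, 64])
assert list(matrix.message.layout.minor_to_major) == [1, 0]

# The same dimensions with an explicit column-major layout.
col_major = Shape(xla_data_pb2.F32, [128, 64], layout=[0, 1])

# Tuple shapes take Shape objects, not integers, as their "dimensions".
pair = Shape(xla_data_pb2.TUPLE, [matrix, col_major])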
Example #2
  def split(cls, tensor, split_dimension, num_devices):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for the
    common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    # Work on a copy so `shape` is not mutated below; each tile covers a
    # ceil(dim_size / num_devices) slice of the split dimension, which must
    # be statically known (not None) at this point.
    tile_shape = list(shape)
    tile_shape[split_dimension] = int(
        math.ceil(tile_shape[split_dimension] / num_devices))
    tile_shape_proto = xla_data_pb2.Shape(
        element_type=xla_data_pb2.F32, dimensions=tile_shape)

    # Lay devices 0..num_devices-1 along the split dimension only:
    # [1, ..., num_devices, ..., 1].
    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_shape=tile_shape_proto,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))
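
A usage sketch for split() under stated assumptions: it presumes Sharding exposes the wrapped OpSharding as a .proto attribute (not shown in the snippet above), and the tensor shape and device count are illustrative.

import tensorflow as tf

# Tile a statically shaped [16, 1024] tensor four ways along dimension 1,
# so each core receives a [16, 256] tile.
x = tf.zeros([16, 1024])
sharding = Sharding.split(x, split_dimension=1, num_devices=4)
# Assuming Sharding keeps the proto it was constructed with:
#   sharding.proto.tile_assignment_dimensions -> [1, 4]
#   sharding.proto.tile_shape.dimensions      -> [16, 256]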