Example #1
 def _create_tuple_proto(self, op):
   shardings = [
       xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
       for _ in op.outputs
   ]
   return xla_data_pb2.OpSharding(
       type=xla_data_pb2.OpSharding.TUPLE, tuple_shardings=shardings)
Example #2
    def testCreateSlotWithCustomSplitXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom split XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5, 10.0, 15.1], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.split(slot,
                                          split_dimension=0,
                                          num_devices=4,
                                          use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.OTHER,
                                        tile_assignment_dimensions=[4],
                                        tile_assignment_devices=range(4)))
Example #3
    def testCreateSlotWithCustomReplicatedXlaSharding(self):
        # slot_creator is used only in optimizer V1.
        # We insert our own custom replicated XLA sharding that overrides the SPMD
        # sharding copied over by the slot_creator.
        with ops.Graph().as_default(), self.cached_session():
            v = variables.Variable([1.0, 2.5], name="var")
            v = xla_sharding.mesh_split(v,
                                        np.array([0, 1]), [0],
                                        use_sharding_op=False)
            with ops.control_dependencies(None):
                slot = slot_creator.create_zeros_slot(v,
                                                      name="slot",
                                                      dtype=dtypes.float64,
                                                      copy_xla_sharding=True)
                slot = xla_sharding.replicate(slot, use_sharding_op=False)

            self.assertNotEqual(xla_sharding.get_tensor_sharding(v),
                                xla_sharding.get_tensor_sharding(slot))

            slot_sharding = xla_sharding.get_tensor_sharding(slot)
            slot_proto = xla_data_pb2.OpSharding()
            slot_proto.ParseFromString(slot_sharding)
            self.assertEqual(
                slot_proto,
                xla_data_pb2.OpSharding(
                    type=xla_data_pb2.OpSharding.REPLICATED))
Example #4
    def apply_to_tensor(self,
                        tensor,
                        assign_tuple_sharding=False,
                        use_sharding_op=False):
        """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to split.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: whether to create a sharding op on `tensor`.

    Returns:
      The tensor with Sharding attribute.
    """
        if len(tensor.op.outputs) > 1 or assign_tuple_sharding:
            proto = self._get_or_create_tuple_proto(tensor.op)
            # We can't mutate an element of old_proto.tuple_shardings, so create
            # a new proto.
            tuple_shardings = list(proto.tuple_shardings)
            tuple_shardings[tensor.value_index] = self._proto
            proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                            tuple_shardings=tuple_shardings)
        else:
            proto = self._proto
        attr_value = proto.SerializeToString()

        if use_sharding_op:
            tensor = tf2xla.sharding(tensor, sharding=attr_value)
        # TODO(jmolloy): This needs to be seriously revisited before declaring
        # this API available for public use.
        # pylint: disable=protected-access
        tensor.op._set_attr('_XlaSharding',
                            attr_value_pb2.AttrValue(s=attr_value))
        return tensor
Example #5
def ipu_shard(index):
  """Control sharding for a set of operations.

  Provides a scope which targets operations onto a particular shard (IPU) of a
  multi-IPU sharded device.

  Args:
    index: The index (or iterable of indices) of the IPU(s) on which to place
      the enclosed operations.

  Returns:
    A context manager which places the enclosed operations on the given IPU(s).
  """

  if hasattr(index, '__iter__'):
    ipus = index
  else:
    ipus = [index]

  proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MAXIMAL,
                                  tile_assignment_devices=ipus)

  attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
  attrs = {"_XlaSharding": attr_value}

  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
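As written above, ipu_shard is a plain generator; to be usable in a with-statement it would normally be wrapped with contextlib.contextmanager (the decorator is not shown in the excerpt). A minimal self-contained sketch of that pattern, using a hypothetical print-based stand-in for the graph attribute scope:

import contextlib

@contextlib.contextmanager
def shard_scope(index):
    # Stand-in for ops.get_default_graph()._attr_scope({...}): the real
    # ipu_shard attaches a serialized MAXIMAL OpSharding proto to every
    # op created while the scope is active.
    print('entering shard', index)
    try:
        yield
    finally:
        print('leaving shard', index)

with shard_scope(1):
    pass  # ops built here would carry the _XlaSharding attribute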
Example #6
  def tile(cls, tile_assignment):
    """Returns a Tiled sharding attribute.

    This causes an op to be partially computed on multiple cores in the
    XLA device.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.

    TODO(jmolloy): This concept is nefarious and is not
    something we really want to expose to users (especially as the
    contract for tile_assignment is very strict).
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('Tile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices)))
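For reference, a minimal NumPy-only sketch (hypothetical 2x2 device mesh) of how tile() turns a tile assignment into the two proto fields it populates:

import numpy as np

# Hypothetical 2x2 device mesh: device 0 computes the top-left tile, etc.
tile_assignment = np.array([[0, 1],
                            [2, 3]])
dims = list(tile_assignment.shape)                      # -> [2, 2]
devices = list(tile_assignment.reshape(-1, order='C'))  # -> [0, 1, 2, 3]
# These become tile_assignment_dimensions and tile_assignment_devices in
# the OpSharding proto of type OTHER, as in Sharding.tile() above.
print(dims, devices)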
Example #7
  def split(cls, tensor, split_dimension, num_devices, input_shape=None):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for the
    common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.
      input_shape: The shape of the original tensor.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    if input_shape:
      shape = input_shape
    else:
      shape = tensor.shape.as_list()
    if (shape[split_dimension] is not None and
        shape[split_dimension] < num_devices):
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))
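A small sketch (plain Python, hypothetical shape) of the tile-assignment bookkeeping split() performs: only the split dimension is divided, every other dimension keeps a single tile.

# Hypothetical: split an [8, 128] tensor 4 ways along dimension 0.
shape = [8, 128]
split_dimension = 0
num_devices = 4

tile_assignment_dims = [1] * len(shape)
tile_assignment_dims[split_dimension] = num_devices
print(tile_assignment_dims)        # -> [4, 1]
print(list(range(num_devices)))    # -> [0, 1, 2, 3] (tile_assignment_devices)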
Example #8
  def manual(cls):
    """Returns a manuall sharding attribute.

    This means the op is manually partitioned by the user and XLA will not
    change the shapes.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MANUAL))
Example #9
  def replicate(cls):
    """Returns a replicated sharding attribute.

    This causes an op to be computed in its entirety independently on all
    cores in the XLA device.
    """
    return Sharding(
        proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED))
Example #10
 def _get_or_create_tuple_proto(self, op):
   try:
     attr = op.get_attr('_XlaSharding')
     proto = xla_data_pb2.OpSharding()
     proto.ParseFromString(attr)
     return proto
   except ValueError:
     return self._create_tuple_proto(op)
Example #11
    def assign_device(cls, core):
        """Returns an AssignDevice sharding attribute.

    This causes an op to be computed in its entirety only on one core in
    the XLA device.
    Args:
      core: The core to assign this Op to.
    """
        return Sharding(
            proto=xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MAXIMAL,
                                          tile_assignment_dimensions=[1],
                                          tile_assignment_devices=[core]))
Example #12
def GetVarSharding(var: tf.Variable) -> TensorShardingSpec:
    """Returns the sharding directly attached to a variable."""
    sharding = xla_sharding.get_op_sharding(var.op)
    if not sharding:
        return TensorShardingSpec.ReplicatedSpec()

    proto = xla_data_pb2.OpSharding()
    proto.ParseFromString(sharding)
    spec_without_padding = TensorShardingSpec.FromXlaOpSharding(proto)
    # Consider uneven padding.
    return TensorShardingSpec.FromFullShape(
        [int(d) for d in var.shape], spec_without_padding.split_dims_mapping,
        spec_without_padding.device_mesh)
Example #13
def set_ipu_shard(op, index):
    """Set shard index for op.

  Args:
    op(tf.Operation): Operator
    index(int): IPU index

  """
    proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.MAXIMAL,
                                    tile_assignment_devices=[index])

    attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
    op._set_attr(sharding._XLA_SHARDING, attr_value)
Example #14
def ipu_shard(index):

  ipus = []
  if hasattr(index, '__iter__'):
    ipus = index
  else:
    ipus = [index]

  proto = xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.MAXIMAL, tile_assignment_devices=ipus)

  attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
  attrs = {"_XlaSharding": attr_value}

  # pylint: disable=protected-access
  with ops.get_default_graph()._attr_scope(attrs):
    yield
Example #15
    def subgroup_tile(cls, tile_assignment, subgroup_modes):
        """Returns a subgroup manual sharding attribute.

    This is similar to tile(), but tile_assignment has one or more dimension
    than the tensor, and subgroup_modes define the sharding types in the last
    dimensions of tile_assignment.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.
      subgroup_modes: sharding types for the dimension more than the tensor
        shape rank.

    Raises:
      TypeError: tile_assignment was not of np.array type or subgroup_modes
        has unsupported sharding type.
    """
        if not isinstance(tile_assignment, _np.ndarray):
            raise TypeError(
                'SubgroupTile assignment must be of type np.ndarray')

        if not isinstance(subgroup_modes, list):
            raise TypeError(
                'subgroup_modes in subgroup manual must be of type list')

        if len(tile_assignment.shape) < len(subgroup_modes):
            raise TypeError(
                'SubgroupTile assignment must have rank at least the length'
                ' of subgroup_modes')

        for sharding_type in subgroup_modes:
            if sharding_type not in [
                    xla_data_pb2.OpSharding.REPLICATED,
                    xla_data_pb2.OpSharding.MANUAL
            ]:
                raise TypeError(
                    'Each sharding_type in subgroup_modes in subgroup manual must '
                    'be of type xla_data_pb2.OpSharding.REPLICATED'
                    ' or xla_data_pb2.OpSharding.MANUAL')
        dims = list(tile_assignment.shape)
        flattened_devices = tile_assignment.reshape(-1, order='C')
        return Sharding(proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            last_tile_dims=list(subgroup_modes)))
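A NumPy-only sketch (hypothetical mesh) of the kind of arguments subgroup_tile() expects: for a rank-2 tensor, a rank-3 tile_assignment whose trailing dimension is governed by subgroup_modes.

import numpy as np

# Hypothetical: rank-2 tensor, split 2 ways on dim 0, with the trailing
# tile_assignment dimension of size 2 marked REPLICATED via subgroup_modes.
tile_assignment = np.arange(4).reshape(2, 1, 2)          # shape [2, 1, 2]
dims = list(tile_assignment.shape)                       # -> [2, 1, 2]
devices = list(tile_assignment.reshape(-1, order='C'))   # -> [0, 1, 2, 3]
# subgroup_modes would be [xla_data_pb2.OpSharding.REPLICATED]: one mode
# for the one extra dimension; it ends up in last_tile_dims.
print(dims, devices)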
Example #16
 def apply_to_tensor(self, tensor):
     """Applies this Sharding attribute to `tensor`."""
     if len(tensor.op.outputs) > 1:
         proto = self._get_or_create_tuple_proto(tensor.op)
         # We can't mutate an element of old_proto.tuple_shardings, so create
         # a new proto.
         tuple_shardings = list(proto.tuple_shardings)
         tuple_shardings[tensor.value_index] = self._proto
         proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                         tuple_shardings=tuple_shardings)
     else:
         proto = self._proto
     attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
     # TODO(jmolloy): This needs to be seriously revisited before declaring
     # this API available for public use.
     # pylint: disable=protected-access
     tensor.op._set_attr('_XlaSharding', attr_value)
Example #17
    def apply_to_tensor(self,
                        tensor,
                        assign_tuple_sharding=False,
                        use_sharding_op=False,
                        unspecified_dims=None):
        """Applies this Sharding attribute to `tensor`.

    Args:
      tensor: A tf.Tensor to split.
      assign_tuple_sharding: If the sharding type should be a tuple.
      use_sharding_op: Whether to create a sharding op on `tensor`.
      unspecified_dims: An optional list of dimensions unspecified.

    Returns:
      The tensor with Sharding attribute.
    """
        if unspecified_dims:
            assert use_sharding_op and not assign_tuple_sharding
        proto = self._proto
        if use_sharding_op:
            if assign_tuple_sharding:
                proto = self._create_tuple_proto(num_outputs=1)
                tensor = tf2xla.sharding(tensor,
                                         sharding=proto.SerializeToString())
            else:
                tensor = tf2xla.sharding(tensor,
                                         sharding=proto.SerializeToString(),
                                         unspecified_dims=unspecified_dims
                                         or [])
        elif assign_tuple_sharding or len(tensor.op.outputs) > 1:
            proto = self._get_or_create_tuple_proto(tensor.op)
            # We can't mutate an element of old_proto.tuple_shardings, so create
            # a new proto.
            tuple_shardings = list(proto.tuple_shardings)
            tuple_shardings[tensor.value_index] = self._proto
            proto = xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                            tuple_shardings=tuple_shardings)

        # TODO(jmolloy): This needs to be seriously revisited before declaring
        # this API available for public use.
        # pylint: disable=protected-access
        tensor.op._set_attr(
            '_XlaSharding',
            attr_value_pb2.AttrValue(s=proto.SerializeToString()))
        return tensor
Example #18
def get_sharding_tile_shape(sharding):
    """Returns the tile assignment shape for a sharded Tensor.

  Args:
    sharding: a serialized OpSharding message describing the layout of a
      sharded Tensor.

  Returns:
    A list, for each dimension of the sharded Tensor, of the number of shards
      into which it has been split. Returns None if the input indicates no tile
      assignments.
  """
    if sharding is None:
        return None
    sharding_message = xla_data_pb2.OpSharding()
    sharding_message.ParseFromString(sharding)
    if sharding_message.tile_assignment_dimensions:
        return sharding_message.tile_assignment_dimensions
    else:
        return None
Example #19
  def split(cls, tensor, split_dimension, num_devices):
    """Returns a Sharding that splits a tensor across a dimension.

    This creates a Tiled attribute, similar to tile(), but easier to use for the
    common case of tiling a tensor N ways in one dimension.

    Args:
      tensor: A tf.Tensor to split.
      split_dimension: The dimension number to split.
      num_devices: The number of cores to split `tensor` over.

    Raises:
      ValueError: The tensor to split was smaller in the split dimension than
        the number of devices to split over.
    """
    tensor.shape.assert_is_fully_defined()
    shape = tensor.shape.as_list()
    if shape[split_dimension] < num_devices:
      raise ValueError('Split dimension was smaller than the required number '
                       'of splits: shape=%r, dimension=%r, num_devices=%r' %
                       (shape, split_dimension, num_devices))

    tile_shape = shape
    tile_shape[split_dimension] = int(
        math.ceil(tile_shape[split_dimension] / num_devices))
    tile_shape_proto = xla_data_pb2.Shape(
        element_type=xla_data_pb2.F32, dimensions=tile_shape)

    tile_assignment_dims = [1] * len(shape)
    tile_assignment_dims[split_dimension] = num_devices

    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_shape=tile_shape_proto,
            tile_assignment_dimensions=tile_assignment_dims,
            tile_assignment_devices=range(num_devices)))
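This older variant additionally records a per-device tile shape; a sketch of the ceil-division it performs (plain Python, hypothetical sizes):

import math

# Hypothetical: a [10, 7] tensor split 4 ways along dimension 0.
shape = [10, 7]
split_dimension = 0
num_devices = 4

tile_shape = list(shape)
tile_shape[split_dimension] = int(
    math.ceil(tile_shape[split_dimension] / num_devices))
print(tile_shape)  # -> [3, 7]: the last device holds a partial (padded) tile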
Example #20
  def partial_tile(cls, tile_assignment):
    """Returns a partially tiled sharding attribute.

    This is similar to tile(), but tile_assignment has one more dimension than
    the tensor, and tiles in the last dimension of tile_assignment are
    replicated.

    Args:
      tile_assignment: An np.ndarray describing the topology of the tiling and
        which device will compute which part of the topology.

    Raises:
      TypeError: tile_assignment was not of np.array type.
    """
    if not isinstance(tile_assignment, _np.ndarray):
      raise TypeError('PartialTile assignment must be of type np.ndarray')
    dims = list(tile_assignment.shape)
    flattened_devices = tile_assignment.reshape(-1, order='C')
    return Sharding(
        proto=xla_data_pb2.OpSharding(
            type=xla_data_pb2.OpSharding.OTHER,
            tile_assignment_dimensions=dims,
            tile_assignment_devices=list(flattened_devices),
            replicate_on_last_tile_dim=True))
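A NumPy-only sketch (hypothetical mesh) of a partial_tile() argument: the extra trailing dimension of tile_assignment lists the devices that replicate each tile.

import numpy as np

# Hypothetical: split a rank-1 tensor 2 ways, with each shard replicated
# across 2 devices (4 devices total).
tile_assignment = np.array([[0, 1],
                            [2, 3]])   # shape [2, 2]: [splits, replicas]
dims = list(tile_assignment.shape)                      # -> [2, 2]
devices = list(tile_assignment.reshape(-1, order='C'))  # -> [0, 1, 2, 3]
# With replicate_on_last_tile_dim=True, devices 0 and 1 replicate shard 0,
# and devices 2 and 3 replicate shard 1.
print(dims, devices)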
Example #21
 def _create_tuple_proto(self, num_outputs):
     shardings = [
         xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.REPLICATED)
     ] * num_outputs
     return xla_data_pb2.OpSharding(type=xla_data_pb2.OpSharding.TUPLE,
                                    tuple_shardings=shardings)
Example #22
def set_ipu_shard(op, index):
  proto = xla_data_pb2.OpSharding(
      type=xla_data_pb2.OpSharding.MAXIMAL, tile_assignment_devices=[index])

  attr_value = attr_value_pb2.AttrValue(s=proto.SerializeToString())
  op._set_attr(sharding._XLA_SHARDING, attr_value)