Example #1
    def _experimental_split_to_logical_devices(self, tensor,
                                               partition_dimensions):
        """See `DistributionStrategy.experimental_split_to_logical_devices`."""
        num_logical_devices_per_replica = self._tpu_devices.shape[1]
        num_partition_splits = np.prod(partition_dimensions)
        input_shape = tensor.shape
        tensor_rank = len(input_shape)

        if tensor_rank != len(partition_dimensions):
            raise ValueError("Length of `partition_dimensions` ({}) must be  "
                             "equal to the rank of `x` ({}).".format(
                                 len(partition_dimensions), tensor_rank))

        for dim_index, dim_size in enumerate(input_shape):
            if dim_size is None:
                continue

            split_size = partition_dimensions[dim_index]
            if dim_size % split_size != 0:
                raise ValueError("Tensor shape at dimension ({}) must be "
                                 "divisible by corresponding value specified "
                                 "by `partition_dimensions` ({}).".format(
                                     dim_index, split_size))

        if num_partition_splits != num_logical_devices_per_replica:
            raise ValueError(
                "Number of logical devices ({}) does not match the "
                "number of partition splits specified ({}).".format(
                    num_logical_devices_per_replica, num_partition_splits))

        tile_assignment = np.arange(num_partition_splits).reshape(
            partition_dimensions)
        return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
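The key step above is the construction of tile_assignment: consecutive device ids 0..N-1 are laid out in the same grid described by partition_dimensions. A minimal sketch of that arithmetic, using hypothetical values (a rank-2 tensor whose second axis is split across 2 logical devices per replica):

import numpy as np

# Hypothetical partitioning chosen for illustration: keep the first axis
# whole, split the second axis across 2 logical devices per replica.
partition_dimensions = [1, 2]
num_partition_splits = np.prod(partition_dimensions)  # -> 2

# Same construction as the method above: consecutive device ids arranged
# into the partition grid.
tile_assignment = np.arange(num_partition_splits).reshape(partition_dimensions)
print(tile_assignment)  # [[0 1]]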
Example #2
def tile_helper(tensor):
    self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
    tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
    self.assertIsInstance(tiled_tensor, ops.Tensor)
    tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
    tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
    # This is the shape of the tile assignment [2, 1, 6].
    expected_shape = [3]
    self.assertEqual(expected_shape, tile_shape)
    return tiled_tensor
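For context, get_sharding_tile_shape reports the shape of the tile-assignment array attached by tile, not the tensor's shape; the assignment passed above is the 1-D array [2, 1, 6], so the expected result is [3]. A minimal numpy sketch of that relationship:

import numpy as np

# The tile assignment used in the test above is a 1-D array of length 3,
# so its shape (which the sharding records) is [3].
tile_assignment = np.array([2, 1, 6])
print(list(tile_assignment.shape))  # [3]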
Example #3
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
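The tiled branch above builds its tile_assignment the same way as Example #1: consecutive device ids reshaped to dims. A sketch with a hypothetical dims of [2, 2] (four devices arranged in a 2x2 grid):

import numpy as np

# Hypothetical dims for the final branch above: a 2x2 partition grid.
dims = [2, 2]
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
print(tile_assignment)
# [[0 1]
#  [2 3]]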
Example #4
  def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
    """Tags appropriate XLA sharding attribute to the dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    if dims is None:
      return xla_sharding.replicate(tensor)
    elif np.prod(dims) == 1:
      return xla_sharding.assign_device(tensor, 0)
    else:
      tile_shape = np.array(tensor.shape.as_list()) // dims
      tile_assignment = np.arange(np.prod(dims)).reshape(dims)
      return xla_sharding.tile(
          tensor=tensor,
          tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
              dtype=np.dtype(tensor.dtype.as_numpy_dtype),
              shape_tuple=tile_shape),
          tile_assignment=tile_assignment)
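This older variant additionally derives a per-device tile_shape by integer-dividing the tensor's static shape by dims. A minimal sketch of that division with hypothetical shapes (an [8, 16] tensor partitioned with dims=[2, 4]):

import numpy as np

# Hypothetical values: an [8, 16] tensor split into a 2x4 grid yields
# per-device tiles of shape [4, 4], mirroring the // computation above.
tensor_shape = np.array([8, 16])
dims = np.array([2, 4])
tile_shape = tensor_shape // dims
print(tile_shape)  # [4 4]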