Example #1
 def _experimental_assign_to_logical_device(self, tensor, logical_device_id):
   """See `DistributionStrategy.experimental_assign_to_logical_device`."""
   num_logical_devices_per_replica = self._tpu_devices.shape[1]
   if (logical_device_id < 0 or
       logical_device_id >= num_logical_devices_per_replica):
      raise ValueError("`logical_device_id` to assign must be lower than the "
                       "total number of logical devices per replica. Received "
                       "logical device id {} but there are only {} logical "
                       "devices per replica.".format(
                           logical_device_id, num_logical_devices_per_replica))
   return xla_sharding.assign_device(tensor, logical_device_id)
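For context, this method backs the public tf.distribute.TPUStrategy.experimental_assign_to_logical_device API. The following sketch is a hypothetical usage outline, assuming TPU hardware with a model-parallel device assignment of two logical devices per replica; the resolver setup, the [1, 1, 1, 2] computation shape, and the input values are illustrative assumptions, not part of the example above.

# Hypothetical usage sketch; requires real TPU hardware and will not run locally.
import tensorflow as tf

resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")
tf.config.experimental_connect_to_cluster(resolver)
topology = tf.tpu.experimental.initialize_tpu_system(resolver)

# Two logical devices per replica, so valid logical_device_id values are 0 and 1.
device_assignment = tf.tpu.experimental.DeviceAssignment.build(
    topology, computation_shape=[1, 1, 1, 2], num_replicas=1)
strategy = tf.distribute.TPUStrategy(
    resolver, experimental_device_assignment=device_assignment)

@tf.function
def step_fn(inputs):
  output = inputs + 1.0
  # Pin the result to logical device 0; an out-of-range id (e.g. 2) would
  # trigger the ValueError raised in the method above.
  return strategy.experimental_assign_to_logical_device(output, 0)

strategy.run(step_fn, args=(tf.ones([2, 2]),))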
Example #2
# Modules the snippet relies on (TensorFlow-internal import paths; they may
# differ between versions).
import numpy as np
from tensorflow.compiler.xla.experimental.xla_sharding import xla_sharding


def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    # No partitioning requested: replicate the tensor on every core.
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    # A single partition: place the whole tensor on logical device 0.
    return xla_sharding.assign_device(tensor, 0)
  else:
    # Otherwise lay the partitions out over logical devices 0..N-1
    # in row-major order.
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
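The tiling branch only needs NumPy to build the device layout, so the mapping from dims to tile_assignment can be illustrated without a TPU. The sketch below uses a hypothetical dims = [2, 4] partitioning.

import numpy as np

# Hypothetical partitioning: split axis 0 into 2 and axis 1 into 4 pieces,
# i.e. 8 tiles handled by logical devices 0..7 in row-major order.
dims = [2, 4]
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
print(tile_assignment)
# [[0 1 2 3]
#  [4 5 6 7]]

# dims=None would replicate the tensor instead, and a dims with
# np.prod(dims) == 1 (e.g. [1, 1]) would place it entirely on logical device 0.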
Example #3
    def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
        """Tags appropriate XLA sharding attribute to the dequeued tensor.

        Args:
          tensor: The dequeued tensor on TPU.
          dims: A list of integers describing how the tensor is partitioned.

        Returns:
          The same tensor with the xla_sharding attribute.
        """
        if dims is None:
            # No partitioning requested: replicate the tensor on every core.
            return xla_sharding.replicate(tensor)
        elif np.prod(dims) == 1:
            # A single partition: place the whole tensor on logical device 0.
            return xla_sharding.assign_device(tensor, 0)
        else:
            # Per-device tile shape: the static tensor shape divided
            # element-wise by the number of partitions along each axis.
            tile_shape = np.array(tensor.shape.as_list()) // dims
            tile_assignment = np.arange(np.prod(dims)).reshape(dims)
            return xla_sharding.tile(
                tensor=tensor,
                tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
                    dtype=np.dtype(tensor.dtype.as_numpy_dtype),
                    shape_tuple=tile_shape),
                tile_assignment=tile_assignment)
Example #4
  def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
    """Tags appropriate XLA sharding attribute to the dequeued tensor.

    Args:
      tensor: The dequeued tensor on TPU.
      dims: A list of integers describing how the tensor is partitioned.

    Returns:
      The same tensor with the xla_sharding attribute.
    """
    if dims is None:
      return xla_sharding.replicate(tensor)
    elif np.prod(dims) == 1:
      return xla_sharding.assign_device(tensor, 0)
    else:
      tile_shape = np.array(tensor.shape.as_list()) // dims
      tile_assignment = np.arange(np.prod(dims)).reshape(dims)
      return xla_sharding.tile(
          tensor=tensor,
          tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
              dtype=np.dtype(tensor.dtype.as_numpy_dtype),
              shape_tuple=tile_shape),
          tile_assignment=tile_assignment)
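Unlike the earlier variant, this one also passes a tile_shape to xla_sharding.tile, computed by dividing the tensor's static shape element-wise by dims. The arithmetic is plain NumPy; the sketch below uses a hypothetical [8, 16] tensor shape and assumes each dimension is evenly divisible by the corresponding entry of dims.

import numpy as np

# Hypothetical static tensor shape and partitioning.
tensor_shape = [8, 16]   # what tensor.shape.as_list() would return
dims = [2, 4]

# Each of the 2 x 4 = 8 tiles covers a [4, 4] slice of the tensor.
tile_shape = np.array(tensor_shape) // dims
tile_assignment = np.arange(np.prod(dims)).reshape(dims)
print(tile_shape)        # [4 4]
print(tile_assignment)   # [[0 1 2 3]
                         #  [4 5 6 7]]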