def _experimental_split_to_logical_devices(self, tensor,
                                           partition_dimensions):
  """See `DistributionStrategy.experimental_split_to_logical_devices`."""
  num_logical_devices_per_replica = self._tpu_devices.shape[1]
  num_partition_splits = np.prod(partition_dimensions)
  input_shape = tensor.shape
  tensor_rank = len(input_shape)

  if tensor_rank != len(partition_dimensions):
    raise ValueError("Length of `partition_dimensions` ({}) must be "
                     "equal to the rank of `x` ({}).".format(
                         len(partition_dimensions), tensor_rank))

  for dim_index, dim_size in enumerate(input_shape):
    if dim_size is None:
      continue

    split_size = partition_dimensions[dim_index]
    if dim_size % split_size != 0:
      raise ValueError("Tensor shape at dimension ({}) must be "
                       "divisible by corresponding value specified "
                       "by `partition_dimensions` ({}).".format(
                           dim_index, split_size))

  if num_partition_splits != num_logical_devices_per_replica:
    raise ValueError(
        "Number of logical devices ({}) does not match the "
        "number of partition splits specified ({}).".format(
            num_logical_devices_per_replica, num_partition_splits))

  tile_assignment = np.arange(num_partition_splits).reshape(
      partition_dimensions)
  return xla_sharding.tile(tensor, tile_assignment, use_sharding_op=True)
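# Illustrative sketch (not part of the strategy code above): shows how the
# `partition_dimensions` argument turns into a tile assignment. The concrete
# values ([2, 2] splits over a [4, 6] tensor) are assumptions chosen only for
# demonstration.
import numpy as np

partition_dimensions = [2, 2]   # split dim 0 into 2 parts and dim 1 into 2 parts
input_shape = [4, 6]            # every dim must be divisible by its split

assert all(size % split == 0
           for size, split in zip(input_shape, partition_dimensions))

# Logical devices 0..3 are laid out row-major over the partition grid,
# exactly like np.arange(...).reshape(partition_dimensions) in the function.
tile_assignment = np.arange(np.prod(partition_dimensions)).reshape(
    partition_dimensions)
print(tile_assignment)
# [[0 1]
#  [2 3]]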
def tile_helper(tensor):
  self.assertIsNone(xla_sharding.get_tensor_sharding(tensor))
  tiled_tensor = xla_sharding.tile(tensor, np.array([2, 1, 6]))
  self.assertIsInstance(tiled_tensor, ops.Tensor)
  tiled_sharding = xla_sharding.get_tensor_sharding(tiled_tensor)
  tile_shape = xla_sharding.get_sharding_tile_shape(tiled_sharding)
  # This is the shape of the tile assignment [2, 1, 6].
  expected_shape = [3]
  self.assertEqual(expected_shape, tile_shape)
  return tiled_tensor
def _tag_sharding_attribute_for_dequeued_tensor(tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(tensor=tensor, tile_assignment=tile_assignment)
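# Minimal sketch (the sample `dims` values are assumptions, not from the
# source) of how the `dims` argument above selects a sharding:
#   dims is None        -> replicate across logical devices
#   np.prod(dims) == 1  -> assign the whole tensor to device 0
#   otherwise           -> tile with a row-major device assignment
import numpy as np

for dims in (None, [1, 1], [2, 4]):
  if dims is None:
    print("replicate")
  elif np.prod(dims) == 1:
    print("assign_device 0")
  else:
    print("tile with assignment:\n",
          np.arange(np.prod(dims)).reshape(dims))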
def _tag_sharding_attribute_for_dequeued_tensor(self, tensor, dims):
  """Tags appropriate XLA sharding attribute to the dequeued tensor.

  Args:
    tensor: The dequeued tensor on TPU.
    dims: A list of integers describing how the tensor is partitioned.

  Returns:
    The same tensor with the xla_sharding attribute.
  """
  if dims is None:
    return xla_sharding.replicate(tensor)
  elif np.prod(dims) == 1:
    return xla_sharding.assign_device(tensor, 0)
  else:
    tile_shape = np.array(tensor.shape.as_list()) // dims
    tile_assignment = np.arange(np.prod(dims)).reshape(dims)
    return xla_sharding.tile(
        tensor=tensor,
        tile_shape=xla_shape.CreateShapeFromDtypeAndTuple(
            dtype=np.dtype(tensor.dtype.as_numpy_dtype),
            shape_tuple=tile_shape),
        tile_assignment=tile_assignment)