def test_new_group(self):
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:3",
        ),
    ])

    pg = dist.new_group(ranks=[1, 2, 3])

    if self.rank >= 1:
        sharded_tensor = _sharded_tensor.empty(spec, 10, 5, process_group=pg)
        self.assertEqual((10, 5), sharded_tensor.size())

        if self.rank == 1 or self.rank == 3:
            # Verify local shard.
            local_shard = sharded_tensor.local_shards()[0]
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata.
            self.assertEqual((self.rank // 2 * 5, 0), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}', local_shard.metadata.placement)

        # Verify global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(2, len(sharding_metadata))
        for rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}', shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        if self.rank == 1 or self.rank == 3:
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))

        owners = {}
        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                shard = remote_shard.to_here()
                self.assertEqual((5, 5), shard.tensor.size())
def test_partial_world_size(self):
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 5)
    self.assertEqual((10, 5), sharded_tensor.size())

    if self.rank <= 1:
        self.assertEqual(1, len(sharded_tensor.local_shards()))
    else:
        self.assertEqual(0, len(sharded_tensor.local_shards()))

    if self.rank <= 1:
        # Verify local shard.
        local_shard = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
        self.assertEqual((5, 5), local_shard.tensor.size())

        # Verify local shard metadata.
        self.assertEqual((self.rank * 5, 0), local_shard.metadata.shard_offsets)
        self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
        self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(2, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)

    # Validate remote shards.
    remote_shards = sharded_tensor.remote_shards
    if self.rank <= 1:
        self.assertEqual(1, len(remote_shards))
    else:
        self.assertEqual(2, len(remote_shards))

    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(1, len(shards))
        for remote_shard in shards:
            self.assertEqual(rpc_rank, remote_shard.owner().id)
            shard = remote_shard.to_here()
            self.assertEqual((5, 5), shard.tensor.size())
def test_with_rpc_names(self):
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="worker0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="worker1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="worker2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="worker3/cuda:3",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    # Verify local shard.
    local_shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
    self.assertEqual((5, 5), local_shard.tensor.size())

    # Verify local shard metadata.
    self.assertEqual((self.rank // 2 * 5, (self.rank % 2) * 5), local_shard.metadata.shard_offsets)
    self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
    self.assertEqual(f'worker{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(4, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'worker{rank}/cuda:{rank}', shard_metadata.placement)

    # Validate remote shards.
    remote_shards = sharded_tensor.remote_shards
    self.assertEqual(3, len(remote_shards))
    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(1, len(shards))
        for remote_shard in shards:
            self.assertEqual(rpc_rank, remote_shard.owner().id)
            shard = remote_shard.to_here()
            self.assertEqual((5, 5), shard.tensor.size())
def test_multiple_local_shards(self):
    self.init_pg()
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())

    if self.rank <= 1:
        self.assertEqual(2, len(sharded_tensor.local_shards()))

        # Verify local shards.
        for idx, local_shard in enumerate(sharded_tensor.local_shards()):
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata.
            self.assertEqual((idx * 5, self.rank * 5), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)
    else:
        self.assertEqual(0, len(sharded_tensor.local_shards()))

    # Verify global metadata.
    sharding_metadata = sharded_tensor.sharding_metadata()
    self.assertEqual(4, len(sharding_metadata))
    for shard_rank, shard_metadata in enumerate(sharding_metadata):
        self.assertEqual((shard_rank // 2 * 5, (shard_rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{shard_rank % 2}/cuda:{shard_rank % 2}', shard_metadata.placement)
def test_grid_sharding(self):
    self.init_pg()
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:3/cuda:3",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    # Verify local shard.
    local_shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
    self.assertEqual((5, 5), local_shard.tensor.size())

    # Verify local shard metadata.
    self.assertEqual((self.rank // 2 * 5, (self.rank % 2) * 5), local_shard.metadata.shard_offsets)
    self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
    self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    sharding_metadata = sharded_tensor.sharding_metadata()
    self.assertEqual(4, len(sharding_metadata))
    for rank, shard_metadata in enumerate(sharding_metadata):
        self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)
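# Quick sanity check of the grid offset arithmetic used in test_grid_sharding above
# (a sketch, not an additional assertion): for a 10x10 tensor split into 5x5 shards,
# rank r maps to offsets (r // 2 * 5, (r % 2) * 5), reproducing the four ShardMetadata
# offsets in the spec.
#
#   >>> [(r // 2 * 5, (r % 2) * 5) for r in range(4)]
#   [(0, 0), (0, 5), (5, 0), (5, 5)]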
def _init_chunked(
    self,
    dims,
    tensor_init_params: TensorInitParams,
):
    current_rank = dist.get_rank(self._process_group)
    sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]

    # Validate the sharding spec.
    if not isinstance(sharding_dim, int):
        raise ValueError(
            f"Sharding dim needs to be an integer, found: {sharding_dim}")
    if sharding_dim >= len(dims) or sharding_dim < -len(dims):
        raise ValueError(f"Invalid sharding dim: {sharding_dim}")

    dim_size = dims[sharding_dim]
    remote_devices = self._sharding_spec.placements  # type: ignore[attr-defined]
    chunks = len(remote_devices)
    # split_size computed similar to 'torch.chunk'
    split_size = (dim_size + chunks - 1) // chunks

    shards_metadata = []
    for idx, remote_device in enumerate(remote_devices):
        rank, local_device = self._parse_and_validate_remote_device(remote_device)

        # Adjust the sharding dim for this rank.
        sharded_dim_size = min(dim_size, split_size * (idx + 1)) - split_size * idx

        if sharded_dim_size > 0:
            # Build sharding_metadata.

            # deepcopy for modification.
            rank_dims = dims.copy()

            rank_offsets = [0] * len(dims)
            rank_offsets[sharding_dim] = split_size * idx
            rank_dims[sharding_dim] = sharded_dim_size

            shard_metadata = ShardMetadata(rank_offsets, rank_dims, remote_device)
            shards_metadata.append(shard_metadata)

            # Build the local shard for the current rank if it is involved in the sharding spec.
            if current_rank == rank:
                # Initialize the local shard.
                local_shard = _create_tensor_from_params(
                    *rank_dims, local_device=local_device,
                    tensor_init_params=tensor_init_params)
                self._local_shards.append(Shard(local_shard, shard_metadata))

    # Build overall metadata
    self._metadata = ShardedTensorMetadata(
        shards_metadata,
        dims,
        tensor_init_params.tensor_properties,
    )
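# Worked example of the torch.chunk-style split used in _init_chunked above, assuming
# an illustrative dim_size of 10 with 4 placements: split_size = (10 + 4 - 1) // 4 = 3,
# and the per-placement size along the sharding dim is min(dim_size, split_size * (idx + 1))
# - split_size * idx.
#
#   >>> dim_size, chunks = 10, 4
#   >>> split_size = (dim_size + chunks - 1) // chunks
#   >>> [min(dim_size, split_size * (i + 1)) - split_size * i for i in range(chunks)]
#   [3, 3, 3, 1]
#
# The last placement receives the remainder, and placements whose size works out to 0
# are skipped (the `if sharded_dim_size > 0` branch above).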
def from_tensor_and_offsets(cls, tensor: torch.Tensor, shard_offsets: List[int], rank: int):
    """
    Creates a Shard of a ShardedTensor from a local torch.Tensor, shard_offsets and rank.

    Args:
        tensor(torch.Tensor): Local tensor for the shard.
        shard_offsets(List[int]): List of integers specifying the offset of the shard on each dimension.
        rank(int): Specify the rank for the shard.
    """
    shard_lengths = list(tensor.size())
    placement = _remote_device(f"rank:{rank}/{str(tensor.device)}")
    shard_meta = ShardMetadata(shard_offsets=shard_offsets, shard_lengths=shard_lengths, placement=placement)
    return Shard(tensor, shard_meta)
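# Minimal usage sketch for from_tensor_and_offsets, assuming it is invoked as a
# classmethod on Shard; the tensor, offsets, and device are illustrative values only.
# It wraps a rank's local 5x5 block of a larger sharded tensor and derives the shard
# lengths and placement from the tensor itself.
#
#   >>> local_tensor = torch.rand(5, 5, device="cuda:0")
#   >>> shard = Shard.from_tensor_and_offsets(local_tensor, shard_offsets=[5, 0], rank=1)
#   >>> shard.metadata.shard_lengths
#   [5, 5]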
def _init_chunked(
    self,
    dtype,
    layout,
    requires_grad,
    pin_memory,
    memory_format,
):
    current_rank = dist.get_rank(self._process_group)
    sharding_dim = self._sharding_spec.dim  # type: ignore[attr-defined]

    # Validate the sharding spec.
    if not isinstance(sharding_dim, int):
        raise ValueError(
            f"Sharding dim needs to be an integer, found: {sharding_dim}"
        )
    if sharding_dim >= len(self._dims) or sharding_dim < -len(self._dims):
        raise ValueError(f"Invalid sharding dim: {sharding_dim}")

    dim_size = self._dims[sharding_dim]
    devices = self._sharding_spec.placements  # type: ignore[attr-defined]
    chunks = len(devices)
    # split_size computed similar to 'torch.chunk'
    split_size = (dim_size + chunks - 1) // chunks

    for idx, device in enumerate(devices):
        if not is_valid_device(device):
            raise ValueError(f"{device} is not a valid device")

        rank, local_device = self._parse_and_validate_remote_device(device)

        # Adjust the sharding dim for this rank.
        sharded_dim_size = min(dim_size, split_size * (idx + 1)) - split_size * idx

        if sharded_dim_size > 0:
            # Build sharding_metadata.

            # deepcopy for modification.
            rank_dims = self._dims.copy()

            rank_offsets = [0] * len(self._dims)
            rank_offsets[sharding_dim] = split_size * idx
            rank_dims[sharding_dim] = sharded_dim_size

            shard_metadata = ShardMetadata(rank_offsets, rank_dims, device)
            self._sharding_metadata.append(shard_metadata)

            # Build the local shard for the current rank if it is involved in the sharding spec.
            if current_rank == rank:
                # Initialize the local shard.
                local_shard = torch.empty(
                    *rank_dims,
                    dtype=dtype,
                    layout=layout,
                    device=local_device,
                    requires_grad=requires_grad,
                    memory_format=memory_format,
                    pin_memory=pin_memory,
                )
                self._local_shards.append(Shard(local_shard, shard_metadata))
def test_enumerable_sharding_spec(self):
    # test valid specs

    # test row-wise sharding
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="cuda:1",
        ),
    ])
    spec.check_tensor(torch.rand(10, 5).size())

    # test row and column sharding
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[3, 3],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 3],
            shard_lengths=[3, 3],
            placement="cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[3, 0],
            shard_lengths=[3, 3],
            placement="cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[3, 3],
            shard_lengths=[3, 3],
            placement="cuda:3",
        ),
    ])
    spec.check_tensor(torch.rand(6, 6).size())

    # test uneven shard sizes.
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[2, 4],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 4],
            shard_lengths=[4, 2],
            placement="cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[2, 0],
            shard_lengths=[4, 4],
            placement="cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[4, 4],
            shard_lengths=[2, 2],
            placement="cuda:3",
        ),
    ])
    spec.check_tensor(torch.rand(6, 6).size())

    # test invalid sharding
    with self.assertRaisesRegex(ValueError, 'not a valid device'):
        ShardMetadata(shard_offsets=[0], shard_lengths=[1], placement="cuda:foo")

    with self.assertRaisesRegex(ValueError, 'same number of elements'):
        ShardMetadata(shard_offsets=[0, 0], shard_lengths=[1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'shard_offsets should be >=0'):
        ShardMetadata(shard_offsets=[-1, 0], shard_lengths=[1, 1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'shard_lengths should be > 0'):
        ShardMetadata(shard_offsets=[0, 0], shard_lengths=[0, 1], placement="cuda:0")

    with self.assertRaisesRegex(ValueError, 'Empty shard list provided'):
        EnumerableShardingSpec([])

    with self.assertRaisesRegex(ValueError, 'Found inconsistent ranks for shards'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[1, 1], placement="cpu"),
            ShardMetadata(shard_offsets=[0, 0, 0], shard_lengths=[1, 1, 1], placement="cpu"),
        ])

    with self.assertRaisesRegex(ValueError, 'Shards.*overlap'):
        EnumerableShardingSpec([
            ShardMetadata(shard_offsets=[0, 0], shard_lengths=[3, 3], placement="cpu"),
            ShardMetadata(shard_offsets=[2, 0], shard_lengths=[3, 3], placement="cpu"),
        ])

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="cuda:1",
        ),
    ])
    with self.assertRaisesRegex(ValueError, 'Rank of tensor is.*but shards rank'):
        spec.check_tensor(torch.rand(10, 10, 10).size())

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="cuda:1",
        ),
    ])
    with self.assertRaisesRegex(ValueError, 'exceeds tensor dim'):
        spec.check_tensor(torch.rand(10, 3).size())

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="cuda:1",
        ),
    ])
    with self.assertRaisesRegex(ValueError, 'does not match tensor volume'):
        spec.check_tensor(torch.rand(10, 10).size())
def test_sharded_linear_errors(self):
    for spec in generate_chunk_sharding_specs_for_test(0):
        fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc1, "bias", spec)
        with self.assertRaisesRegex(TypeError, 'input and bias need to be torch.Tensor'):
            fc1(torch.rand(10, 10).cuda(self.rank))

        fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
        shard_parameter(fc2, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Input needs to have at least 1 dim'):
            fc2(torch.tensor(1).cuda(self.rank))

        fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc3.weight = torch.nn.Parameter(torch.rand(10, 10, 10).cuda(self.rank))
        shard_parameter(fc3, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Weight needs to have exactly 2 dims'):
            fc3(torch.rand(10, 10).cuda(self.rank))

        fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
        fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
        shard_parameter(fc4, "weight", spec)
        with self.assertRaisesRegex(ValueError, 'Bias needs to have exactly 1 dim'):
            fc4(torch.rand(10, 10).cuda(self.rank))

        fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
        shard_parameter(fc5, "weight", spec)
        with self.assertRaisesRegex(
                ValueError, 'Input dim: 13 does not match appropriate weight dim: 7'):
            fc5(torch.rand(20, 10, 13).cuda(self.rank))

        fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
        del fc6.weight
        enumerable_spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_sizes=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_sizes=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_sizes=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_sizes=[5, 5],
                placement="rank:3/cuda:3",
            ),
        ])

        fc6.weight = empty(enumerable_spec, 10, 10)
        with self.assertRaisesRegex(
                ValueError, 'Only ChunkShardingSpec supported for ShardedTensor ops!'):
            fc6(torch.rand(10, 10).cuda(self.rank))

        fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
        multiple_local_shard_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:3/cuda:3",
            ],
        )
        del fc7.weight
        fc7.weight = empty(multiple_local_shard_spec, 80, 10)
        with self.assertRaisesRegex(ValueError, 'Only one local shard supported!'):
            fc7(torch.rand(10, 10).cuda(self.rank))
def test_uneven_shards(self):
    self.init_pg()
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[2, 4],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 4],
            shard_lengths=[4, 2],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[2, 0],
            shard_lengths=[4, 4],
            placement="rank:2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[4, 4],
            shard_lengths=[2, 2],
            placement="rank:3/cuda:3",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 6, 6)
    self.assertEqual((6, 6), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    def verify_size(rank, tensor_dims):
        if rank == 0:
            self.assertEqual((2, 4), tensor_dims)
        elif rank == 1:
            self.assertEqual((4, 2), tensor_dims)
        elif rank == 2:
            self.assertEqual((4, 4), tensor_dims)
        elif rank == 3:
            self.assertEqual((2, 2), tensor_dims)

    def verify_offsets(rank, offsets):
        if rank == 0:
            self.assertEqual((0, 0), offsets)
        elif rank == 1:
            self.assertEqual((0, 4), offsets)
        elif rank == 2:
            self.assertEqual((2, 0), offsets)
        elif rank == 3:
            self.assertEqual((4, 4), offsets)

    # Verify local shard.
    local_shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
    verify_size(self.rank, local_shard.tensor.size())

    # Verify local shard metadata.
    verify_offsets(self.rank, local_shard.metadata.shard_offsets)
    verify_size(self.rank, local_shard.metadata.shard_lengths)
    self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(4, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        verify_offsets(rank, shard_metadata.shard_offsets)
        verify_size(rank, shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)
def test_sharded_tensor_metadata(self):
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:3/cuda:3",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(torch.Size([10, 10]), sharded_tensor_metadata.size)
    self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
    self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
    self.assertEqual(False, sharded_tensor_metadata.requires_grad)
    self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format)
    self.assertEqual(False, sharded_tensor_metadata.pin_memory)

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10, requires_grad=True)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(True, sharded_tensor_metadata.requires_grad)

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10, dtype=torch.double)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(torch.double, sharded_tensor_metadata.dtype)

    # Need CPU for pin_memory
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cpu",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cpu",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cpu",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:3/cpu",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10, pin_memory=True)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(True, sharded_tensor_metadata.pin_memory)