def test_sharded_tensor_sizes(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    # Test with *args
    sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
    self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

    # Test with single *args
    sharded_tensor = _sharded_tensor.empty(spec, 10)
    self.assertEqual(torch.Size([10]), sharded_tensor.size())

    # Test with list
    sharded_tensor = _sharded_tensor.empty(spec, [10, 20])
    self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

    # Test with tuple
    sharded_tensor = _sharded_tensor.empty(spec, (10, 20))
    self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

    with self.assertRaises(TypeError):
        sharded_tensor = _sharded_tensor.empty(spec, 'foo')
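# The four call forms exercised above (varargs, single vararg, list, tuple) all
# normalize to a torch.Size, and a non-integer argument raises TypeError. A
# minimal sketch of that normalization, assuming torch-style size handling;
# this helper is illustrative only and not part of the _sharded_tensor API:
def _normalize_size_sketch(*size):
    # Unwrap a single list/tuple argument into individual dimensions.
    if len(size) == 1 and isinstance(size[0], (list, tuple)):
        size = tuple(size[0])
    # Reject anything that is not an integer dimension, e.g. 'foo'.
    if not all(isinstance(dim, int) for dim in size):
        raise TypeError('size must be ints or a single list/tuple of ints')
    return torch.Size(size)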
def test_sharded_tensor_metadata(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size)
    self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
    self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
    self.assertEqual(False, sharded_tensor_metadata.requires_grad)
    self.assertEqual(torch.contiguous_format, sharded_tensor_metadata.memory_format)
    self.assertEqual(False, sharded_tensor_metadata.pin_memory)

    sharded_tensor = _sharded_tensor.empty(spec, 10, 20, requires_grad=True)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(True, sharded_tensor_metadata.requires_grad)

    sharded_tensor = _sharded_tensor.empty(spec, 10, 20, dtype=torch.double)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(torch.double, sharded_tensor_metadata.dtype)

    # Need CPU placements for pin_memory.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cpu",
            "rank:1/cpu",
            "rank:2/cpu",
            "rank:3/cpu",
        ],
    )

    sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True)
    sharded_tensor_metadata = sharded_tensor.metadata()
    self.assertEqual(True, sharded_tensor_metadata.pin_memory)
def test_new_group(self):
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:3",
        ),
    ])

    pg = dist.new_group(ranks=[1, 2, 3])

    if self.rank >= 1:
        sharded_tensor = _sharded_tensor.empty(spec, 10, 5, process_group=pg)
        self.assertEqual((10, 5), sharded_tensor.size())

        if self.rank == 1 or self.rank == 3:
            # Verify local shard.
            local_shard = sharded_tensor.local_shards()[0]
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata.
            self.assertEqual((self.rank // 2 * 5, 0), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}', local_shard.metadata.placement)

        # Verify global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(2, len(sharding_metadata))
        for rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}', shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        if self.rank == 1 or self.rank == 3:
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                shard = remote_shard.to_here()
                self.assertEqual((5, 5), shard.tensor.size())
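# In both test_new_group variants here, the "rank:N" in a placement refers to a
# rank inside the subgroup, not the global rank: dist.new_group(ranks=[1, 2, 3])
# maps subgroup rank r to global rank ranks[r]. A sketch of that mapping,
# assuming this subgroup-relative interpretation (names are illustrative):
def _subgroup_to_global_sketch(subgroup_rank, ranks=(1, 2, 3)):
    # Subgroup rank 0 -> global rank 1, subgroup rank 2 -> global rank 3,
    # which is why placements "rank:0/cuda:1" and "rank:2/cuda:3" land on
    # global ranks 1 and 3 and the test checks self.rank == 1 or 3.
    return ranks[subgroup_rank]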
def test_partial_world_size(self):
    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 5)
    self.assertEqual((10, 5), sharded_tensor.size())

    if self.rank <= 1:
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Verify local shard.
        local_shard = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
        self.assertEqual((5, 5), local_shard.tensor.size())

        # Verify local shard metadata.
        self.assertEqual((self.rank * 5, 0), local_shard.metadata.shard_offsets)
        self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
        self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)
    else:
        self.assertEqual(0, len(sharded_tensor.local_shards()))

    # Verify global metadata.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(2, len(shards_metadata))
    for rank, shard_metadata in enumerate(shards_metadata):
        self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)

    # Validate remote shards.
    remote_shards = sharded_tensor.remote_shards
    if self.rank <= 1:
        self.assertEqual(1, len(remote_shards))
    else:
        self.assertEqual(2, len(remote_shards))

    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(1, len(shards))
        for remote_shard in shards:
            self.assertEqual(rpc_rank, remote_shard.owner().id)
            shard = remote_shard.to_here()
            self.assertEqual((5, 5), shard.tensor.size())
def test_multiple_local_shards(self):
    self.init_pg()

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())

    if self.rank <= 1:
        self.assertEqual(2, len(sharded_tensor.local_shards()))

        # Verify local shards.
        for idx, local_shard in enumerate(sharded_tensor.local_shards()):
            self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata.
            self.assertEqual((idx * 5, self.rank * 5), local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)
    else:
        self.assertEqual(0, len(sharded_tensor.local_shards()))

    # Verify global metadata.
    sharding_metadata = sharded_tensor.sharding_metadata()
    self.assertEqual(4, len(sharding_metadata))
    for shard_rank, shard_metadata in enumerate(sharding_metadata):
        self.assertEqual((shard_rank // 2 * 5, (shard_rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{shard_rank % 2}/cuda:{shard_rank % 2}', shard_metadata.placement)
def test_new_group(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:1/cuda:2",
            "rank:2/cuda:3",
        ],
    )

    pg = dist.new_group(ranks=[1, 2, 3])

    if self.rank >= 1:
        sharded_tensor = _sharded_tensor.empty(spec, 10, 20, process_group=pg)

        # Validate local shard.
        local_shards = sharded_tensor.local_shards()
        if self.rank >= 2:
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
            self.assertEqual((5, 20), local_shard.size())
        else:
            self.assertEqual(0, len(local_shards))

        # Validate global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(2, len(shards_metadata))
        for shard_rank, shard_metadata in enumerate(shards_metadata):
            self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets)
            self.assertEqual([5, 20], shard_metadata.shard_lengths)
            self.assertEqual(
                f'rank:{shard_rank + 1}/cuda:{shard_rank + 2}', shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        if self.rank >= 2:
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                shard = remote_shard.to_here()
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                self.assertEqual(f'rank:{rpc_rank - 1}/cuda:{rpc_rank}', shard.metadata.placement)
                self.assertEqual((5, 20), shard.tensor.size())
def test_complete_world_size(self):
    for dim in [0, -2]:
        spec = ChunkShardingSpec(
            dim=dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)

        # Validate local shard.
        local_shards = sharded_tensor.local_shards()
        self.assertEqual(1, len(local_shards))
        local_shard = local_shards[0].tensor
        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.device)
        if self.rank == 3:
            self.assertEqual((1, 20), local_shard.size())
        else:
            self.assertEqual((3, 20), local_shard.size())

        # Validate global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(4, len(shards_metadata))
        for rank, shard_metadata in enumerate(shards_metadata):
            self.assertEqual([rank * 3, 0], shard_metadata.shard_offsets)
            if rank == 3:
                self.assertEqual([1, 20], shard_metadata.shard_lengths)
            else:
                self.assertEqual([3, 20], shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        self.assertEqual(3, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))
            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                shard = remote_shard.to_here()
                self.assertEqual(f'rank:{rpc_rank}/cuda:{rpc_rank}', shard.metadata.placement)
                if rpc_rank == 3:
                    self.assertEqual((1, 20), shard.tensor.size())
                else:
                    self.assertEqual((3, 20), shard.tensor.size())
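# The 3/3/3/1 split asserted above follows torch.chunk-style semantics: each
# shard gets ceil(10 / 4) = 3 rows along the sharding dim, with the last shard
# taking the remainder. A minimal standalone sketch of that arithmetic
# (illustrative helper, not part of the sharding API), which also accounts for
# the uniform 2-row shards in test_multiple_local_shards below (16 rows over
# 8 placements):
import math

def _chunk_sizes_sketch(dim_size, num_shards):
    # Full chunks of ceil(dim_size / num_shards); the final chunk may be short.
    full = math.ceil(dim_size / num_shards)
    sizes = []
    remaining = dim_size
    while remaining > 0:
        sizes.append(min(full, remaining))
        remaining -= full
    return sizes

assert _chunk_sizes_sketch(10, 4) == [3, 3, 3, 1]
assert _chunk_sizes_sketch(16, 8) == [2, 2, 2, 2, 2, 2, 2, 2]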
def test_grid_sharding(self):
    self.init_pg()

    spec = EnumerableShardingSpec([
        ShardMetadata(
            shard_offsets=[0, 0],
            shard_lengths=[5, 5],
            placement="rank:0/cuda:0",
        ),
        ShardMetadata(
            shard_offsets=[0, 5],
            shard_lengths=[5, 5],
            placement="rank:1/cuda:1",
        ),
        ShardMetadata(
            shard_offsets=[5, 0],
            shard_lengths=[5, 5],
            placement="rank:2/cuda:2",
        ),
        ShardMetadata(
            shard_offsets=[5, 5],
            shard_lengths=[5, 5],
            placement="rank:3/cuda:3",
        ),
    ])

    sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
    self.assertEqual((10, 10), sharded_tensor.size())
    self.assertEqual(1, len(sharded_tensor.local_shards()))

    # Verify local shard.
    local_shard = sharded_tensor.local_shards()[0]
    self.assertEqual(torch.device(f'cuda:{self.rank}'), local_shard.tensor.device)
    self.assertEqual((5, 5), local_shard.tensor.size())

    # Verify local shard metadata.
    self.assertEqual((self.rank // 2 * 5, (self.rank % 2) * 5), local_shard.metadata.shard_offsets)
    self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
    self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}', local_shard.metadata.placement)

    # Verify global metadata.
    sharding_metadata = sharded_tensor.sharding_metadata()
    self.assertEqual(4, len(sharding_metadata))
    for rank, shard_metadata in enumerate(sharding_metadata):
        self.assertEqual((rank // 2 * 5, (rank % 2) * 5), shard_metadata.shard_offsets)
        self.assertEqual((5, 5), shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{rank}/cuda:{rank}', shard_metadata.placement)
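# The (rank // 2 * 5, (rank % 2) * 5) arithmetic above lays the four ranks out
# row-major on a 2x2 grid of 5x5 shards inside the 10x10 tensor. A small
# illustrative helper (hypothetical, for readability only) making that
# mapping explicit:
def _grid_offsets_sketch(rank, grid_cols=2, shard_rows=5, shard_cols=5):
    # rank 0 -> (0, 0), rank 1 -> (0, 5), rank 2 -> (5, 0), rank 3 -> (5, 5).
    return (rank // grid_cols * shard_rows, (rank % grid_cols) * shard_cols)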
def test_multiple_local_shards(self):
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )

    sharded_tensor = _sharded_tensor.empty(spec, 16, 20)

    # Validate local shards.
    local_shards = sharded_tensor.local_shards()
    self.assertEqual(2, len(local_shards))
    for local_shard in local_shards:
        self.assertEqual(torch.device(f"cuda:{self.rank}"), local_shard.tensor.device)
        self.assertEqual((2, 20), local_shard.tensor.size())

    # Validate global metadata.
    sharded_tensor_metadata = sharded_tensor.metadata()
    shards_metadata = sharded_tensor_metadata.shards_metadata
    self.assertEqual(8, len(shards_metadata))
    for shard_idx, shard_metadata in enumerate(shards_metadata):
        self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets)
        self.assertEqual([2, 20], shard_metadata.shard_lengths)
        self.assertEqual(f'rank:{shard_idx % 4}/cuda:{shard_idx % 4}', shard_metadata.placement)

    # Validate remote shards.
    remote_shards = sharded_tensor.remote_shards
    self.assertEqual(3, len(remote_shards))

    for rpc_rank, shards in remote_shards.items():
        self.assertEqual(2, len(shards))
        for remote_shard in shards:
            shard = remote_shard.to_here()
            self.assertEqual((2, 20), shard.tensor.size())
            self.assertEqual(rpc_rank, remote_shard.owner().id)