Example #1: `_sharded_tensor.empty` accepts the global size as varargs, a list, or a tuple, and rejects other types with a TypeError.
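The snippets below are methods of a multi-process distributed test case, so each process sees its own `self.rank` and a process group is already initialized (or set up via `self.init_pg()`). A minimal sketch of the imports they assume, based on the prototype `_sharded_tensor` API; the exact module paths may differ between PyTorch versions:

# Module paths assumed from the prototype API; adjust for your PyTorch version.
import torch
import torch.distributed as dist
from torch.distributed import _sharded_tensor
from torch.distributed._sharding_spec import (
    ChunkShardingSpec,
    EnumerableShardingSpec,
    ShardMetadata,
)
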
    def test_sharded_tensor_sizes(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        # Test with *args
        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
        self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

        # Test with single *args
        sharded_tensor = _sharded_tensor.empty(spec, 10)
        self.assertEqual(torch.Size([10]), sharded_tensor.size())

        # Test with list
        sharded_tensor = _sharded_tensor.empty(spec, [10, 20])
        self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

        # Test with tuple
        sharded_tensor = _sharded_tensor.empty(spec, (10, 20))
        self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

        with self.assertRaises(TypeError):
            sharded_tensor = _sharded_tensor.empty(spec, 'foo')
Example #2: ShardedTensor metadata reports the size, dtype, layout, requires_grad, memory_format, and pin_memory options passed to `empty()`.
    def test_sharded_tensor_metadata(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
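        # No options overridden: the metadata should report empty()'s defaults
        # (float32, strided layout, contiguous format, no grad, not pinned).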
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size)
        self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
        self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
        self.assertEqual(False, sharded_tensor_metadata.requires_grad)
        self.assertEqual(torch.contiguous_format,
                         sharded_tensor_metadata.memory_format)
        self.assertEqual(False, sharded_tensor_metadata.pin_memory)

        sharded_tensor = _sharded_tensor.empty(spec,
                                               10,
                                               20,
                                               requires_grad=True)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(True, sharded_tensor_metadata.requires_grad)

        sharded_tensor = _sharded_tensor.empty(spec,
                                               10,
                                               20,
                                               dtype=torch.double)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(torch.double, sharded_tensor_metadata.dtype)

        # Need CPU for pin_memory
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cpu",
                "rank:1/cpu",
                "rank:2/cpu",
                "rank:3/cpu",
            ],
        )

        sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(True, sharded_tensor_metadata.pin_memory)
Example #3: an EnumerableShardingSpec combined with a new process group; the placement ranks refer to positions within that subgroup.
    def test_new_group(self):
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="rank:0/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="rank:2/cuda:3",
            ),
        ])

        pg = dist.new_group(ranks=[1, 2, 3])
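        # The placement ranks above are relative to this subgroup: 'rank:0'
        # resolves to global rank 1 and 'rank:2' to global rank 3.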

        if self.rank >= 1:
            sharded_tensor = _sharded_tensor.empty(spec,
                                                   10,
                                                   5,
                                                   process_group=pg)
            self.assertEqual((10, 5), sharded_tensor.size())
            if self.rank == 1 or self.rank == 3:
                # Verify local shard.
                local_shard = sharded_tensor.local_shards()[0]
                self.assertEqual(torch.device(f'cuda:{self.rank}'),
                                 local_shard.tensor.device)
                self.assertEqual((5, 5), local_shard.tensor.size())

                # Verify local shard metadata.
                self.assertEqual((self.rank // 2 * 5, 0),
                                 local_shard.metadata.shard_offsets)
                self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
                self.assertEqual(f'rank:{self.rank - 1}/cuda:{self.rank}',
                                 local_shard.metadata.placement)

            # Verify global metadata.
            sharding_metadata = sharded_tensor.sharding_metadata()
            self.assertEqual(2, len(sharding_metadata))
            for rank, shard_metadata in enumerate(sharding_metadata):
                self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
                self.assertEqual((5, 5), shard_metadata.shard_lengths)
                self.assertEqual(f'rank:{rank * 2}/cuda:{rank * 2 + 1}',
                                 shard_metadata.placement)

            # Validate remote shards.
            remote_shards = sharded_tensor.remote_shards
            if self.rank == 1 or self.rank == 3:
                self.assertEqual(1, len(remote_shards))
            else:
                self.assertEqual(2, len(remote_shards))

            for rpc_rank, shards in remote_shards.items():
                self.assertEqual(1, len(shards))

                for remote_shard in shards:
                    self.assertEqual(rpc_rank, remote_shard.owner().id)
                    shard = remote_shard.to_here()
                    self.assertEqual((5, 5), shard.tensor.size())
Example #4: an EnumerableShardingSpec that covers only part of the world, so some ranks hold no local shards.
    def test_partial_world_size(self):
        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="rank:1/cuda:1",
            ),
        ])
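        # Only ranks 0 and 1 appear in the spec; every rank participates in the
        # call below, but ranks 2 and 3 end up with no local shards.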

        sharded_tensor = _sharded_tensor.empty(spec, 10, 5)
        self.assertEqual((10, 5), sharded_tensor.size())
        if self.rank <= 1:
            self.assertEqual(1, len(sharded_tensor.local_shards()))
        else:
            self.assertEqual(0, len(sharded_tensor.local_shards()))

        if self.rank <= 1:
            # Verify local shard.
            local_shard = sharded_tensor.local_shards()[0]
            self.assertEqual(torch.device(f'cuda:{self.rank}'),
                             local_shard.tensor.device)
            self.assertEqual((5, 5), local_shard.tensor.size())

            # Verify local shard metadata.
            self.assertEqual((self.rank * 5, 0),
                             local_shard.metadata.shard_offsets)
            self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
            self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}',
                             local_shard.metadata.placement)

        # Verify global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(2, len(shards_metadata))
        for rank, shard_metadata in enumerate(shards_metadata):
            self.assertEqual((rank * 5, 0), shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank}/cuda:{rank}',
                             shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        if self.rank <= 1:
            self.assertEqual(1, len(remote_shards))
        else:
            self.assertEqual(2, len(remote_shards))

        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(1, len(shards))

            for remote_shard in shards:
                self.assertEqual(rpc_rank, remote_shard.owner().id)
                shard = remote_shard.to_here()
                self.assertEqual((5, 5), shard.tensor.size())
Example #5: an EnumerableShardingSpec that assigns multiple local shards to the same rank.
    def test_multiple_local_shards(self):
        self.init_pg()

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_lengths=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_lengths=[5, 5],
                placement="rank:1/cuda:1",
            )
        ])
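        # A 2x2 grid of (5, 5) shards: rank 0 owns the left column, rank 1 owns
        # the right column, and ranks 2 and 3 own no shards.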

        sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
        self.assertEqual((10, 10), sharded_tensor.size())

        if self.rank <= 1:
            self.assertEqual(2, len(sharded_tensor.local_shards()))

            # Verify local shards.
            for idx, local_shard in enumerate(sharded_tensor.local_shards()):
                self.assertEqual(torch.device(f'cuda:{self.rank}'),
                                 local_shard.tensor.device)
                self.assertEqual((5, 5), local_shard.tensor.size())

                # Verify local shard metadata.
                self.assertEqual((idx * 5, self.rank * 5),
                                 local_shard.metadata.shard_offsets)
                self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
                self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}',
                                 local_shard.metadata.placement)
        else:
            self.assertEqual(0, len(sharded_tensor.local_shards()))

        # Verify global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(4, len(sharding_metadata))
        for shard_rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual((shard_rank // 2 * 5, (shard_rank % 2) * 5),
                             shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{shard_rank % 2}/cuda:{shard_rank % 2}',
                             shard_metadata.placement)
Example #6: a ChunkShardingSpec used with a new process group; placement ranks are relative to that subgroup.
    def test_new_group(self):

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:2",
                "rank:2/cuda:3",
            ],
        )

        pg = dist.new_group(ranks=[1, 2, 3])
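        # 'rank:1' and 'rank:2' are relative to this subgroup, so the two chunks
        # land on global ranks 2 and 3 (cuda:2 and cuda:3).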
        if self.rank >= 1:
            sharded_tensor = _sharded_tensor.empty(spec,
                                                   10,
                                                   20,
                                                   process_group=pg)

            # Validate local shard.
            local_shards = sharded_tensor.local_shards()
            if self.rank >= 2:
                self.assertEqual(1, len(local_shards))
                local_shard = local_shards[0].tensor
                self.assertEqual(torch.device(f"cuda:{self.rank}"),
                                 local_shard.device)
                self.assertEqual((5, 20), local_shard.size())
            else:
                self.assertEqual(0, len(local_shards))

            # Validate global metadata.
            sharded_tensor_metadata = sharded_tensor.metadata()
            shards_metadata = sharded_tensor_metadata.shards_metadata
            self.assertEqual(2, len(shards_metadata))

            for shard_rank, shard_metadata in enumerate(shards_metadata):
                self.assertEqual([shard_rank * 5, 0],
                                 shard_metadata.shard_offsets)
                self.assertEqual([5, 20], shard_metadata.shard_lengths)
                self.assertEqual(
                    f'rank:{shard_rank + 1}/cuda:{shard_rank + 2}',
                    shard_metadata.placement)

            # Validate remote shards.
            remote_shards = sharded_tensor.remote_shards
            if self.rank >= 2:
                self.assertEqual(1, len(remote_shards))
            else:
                self.assertEqual(2, len(remote_shards))

            for rpc_rank, shards in remote_shards.items():
                self.assertEqual(1, len(shards))
                for remote_shard in shards:
                    shard = remote_shard.to_here()
                    self.assertEqual(rpc_rank, remote_shard.owner().id)
                    self.assertEqual(f'rank:{rpc_rank - 1}/cuda:{rpc_rank}',
                                     shard.metadata.placement)
                    self.assertEqual((5, 20), shard.tensor.size())
Example #7: a ChunkShardingSpec across the full world size, exercised with both a positive and a negative sharding dimension.
    def test_complete_world_size(self):

        for dim in [0, -2]:
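            # dim=-2 addresses the same (first) dimension as dim=0 for a 2-D
            # tensor, so the loop exercises both spellings of the row dimension.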
            spec = ChunkShardingSpec(
                dim=dim,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                ],
            )
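            # Chunking 10 rows across 4 placements yields shards of 3, 3, 3 and
            # 1 rows, which the assertions below check per rank.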
            sharded_tensor = _sharded_tensor.empty(spec, 10, 20)

            # Validate local shard.
            local_shards = sharded_tensor.local_shards()
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.device)
            if self.rank == 3:
                self.assertEqual((1, 20), local_shard.size())
            else:
                self.assertEqual((3, 20), local_shard.size())

            # Validate global metadata.
            sharded_tensor_metadata = sharded_tensor.metadata()
            shards_metadata = sharded_tensor_metadata.shards_metadata
            self.assertEqual(4, len(shards_metadata))

            for rank, shard_metadata in enumerate(shards_metadata):
                self.assertEqual([rank * 3, 0], shard_metadata.shard_offsets)
                if rank == 3:
                    self.assertEqual([1, 20], shard_metadata.shard_lengths)
                else:
                    self.assertEqual([3, 20], shard_metadata.shard_lengths)
                self.assertEqual(f'rank:{rank}/cuda:{rank}',
                                 shard_metadata.placement)

            # Validate remote shards.
            remote_shards = sharded_tensor.remote_shards
            self.assertEqual(3, len(remote_shards))

            for rpc_rank, shards in remote_shards.items():
                self.assertEqual(1, len(shards))
                for remote_shard in shards:
                    self.assertEqual(rpc_rank, remote_shard.owner().id)
                    shard = remote_shard.to_here()
                    self.assertEqual(f'rank:{rpc_rank}/cuda:{rpc_rank}',
                                     shard.metadata.placement)
                    if rpc_rank == 3:
                        self.assertEqual((1, 20), shard.tensor.size())
                    else:
                        self.assertEqual((3, 20), shard.tensor.size())
Example #8: grid sharding with an EnumerableShardingSpec; each rank owns one tile of a 2x2 grid.
    def test_grid_sharding(self):
        self.init_pg()

        spec = EnumerableShardingSpec([
            ShardMetadata(
                shard_offsets=[0, 0],
                shard_lengths=[5, 5],
                placement="rank:0/cuda:0",
            ),
            ShardMetadata(
                shard_offsets=[0, 5],
                shard_lengths=[5, 5],
                placement="rank:1/cuda:1",
            ),
            ShardMetadata(
                shard_offsets=[5, 0],
                shard_lengths=[5, 5],
                placement="rank:2/cuda:2",
            ),
            ShardMetadata(
                shard_offsets=[5, 5],
                shard_lengths=[5, 5],
                placement="rank:3/cuda:3",
            )
        ])
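        # Each of the four ranks owns exactly one (5, 5) tile of the 2x2 grid.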

        sharded_tensor = _sharded_tensor.empty(spec, 10, 10)
        self.assertEqual((10, 10), sharded_tensor.size())
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Verify local shard.
        local_shard = sharded_tensor.local_shards()[0]
        self.assertEqual(torch.device(f'cuda:{self.rank}'),
                         local_shard.tensor.device)
        self.assertEqual((5, 5), local_shard.tensor.size())

        # Verify local shard metadata.
        self.assertEqual((self.rank // 2 * 5, (self.rank % 2) * 5),
                         local_shard.metadata.shard_offsets)
        self.assertEqual((5, 5), local_shard.metadata.shard_lengths)
        self.assertEqual(f'rank:{self.rank}/cuda:{self.rank}',
                         local_shard.metadata.placement)

        # Verify global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(4, len(sharding_metadata))
        for rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual((rank // 2 * 5, (rank % 2) * 5),
                             shard_metadata.shard_offsets)
            self.assertEqual((5, 5), shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{rank}/cuda:{rank}',
                             shard_metadata.placement)
Example #9: a ChunkShardingSpec with repeated placements, so each rank ends up with multiple local shards.
    def test_multiple_local_shards(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
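        # Eight placements over 16 rows: every shard is (2, 20) and each rank
        # owns two shards, since the placements cycle through the ranks twice.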
        sharded_tensor = _sharded_tensor.empty(spec, 16, 20)

        # Validate local shards.
        local_shards = sharded_tensor.local_shards()
        self.assertEqual(2, len(local_shards))
        for local_shard in local_shards:
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.tensor.device)
            self.assertEqual((2, 20), local_shard.tensor.size())

        # Validate global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(8, len(shards_metadata))

        for shard_idx, shard_metadata in enumerate(shards_metadata):
            self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets)
            self.assertEqual([2, 20], shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{shard_idx % 4}/cuda:{shard_idx % 4}',
                             shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        self.assertEqual(3, len(remote_shards))
        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(2, len(shards))
            for remote_shard in shards:
                shard = remote_shard.to_here()
                self.assertEqual((2, 20), shard.tensor.size())
                self.assertEqual(rpc_rank, remote_shard.owner().id)