def generate_chunk_sharding_specs_for_test(sharding_dim):
    return [
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        ),
        # Test different ordering. (Case 1)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        ),
        # Test different ordering. (Case 2)
        ChunkShardingSpec(
            dim=sharding_dim,
            placements=[
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
            ],
        ),
    ]
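A minimal usage sketch (an addition, not part of the original examples; it assumes an older PyTorch build that still exposes the experimental torch.distributed._sharding_spec module, the same import path the state_dict example below uses): building the specs needs no initialized process group, so the three placement orderings can be inspected directly, mirroring how test_sharded_linear_errors further down iterates over them.

from torch.distributed._sharding_spec import ChunkShardingSpec

for spec in generate_chunk_sharding_specs_for_test(sharding_dim=0):
    # Each spec shards dim 0 across ranks 0-3; only the placement order differs.
    print(spec.dim, [str(p) for p in spec.placements])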
Example 2
    def test_sharded_linear_colwise(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        self._run_sharded_linear(spec, [5, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [5, 21], [21, 11], 0)
        self._run_sharded_linear(spec, [5, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [5, 15], [15, 14], 0)

        # Test different ordering.
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
                "rank:0/cuda:0",
                "rank:3/cuda:3",
                "rank:2/cuda:2",
            ],
        )

        self._run_sharded_linear(spec, [5, 17], [17, 12], 0)
        self._run_sharded_linear(spec, [5, 21], [21, 11], 0)
        self._run_sharded_linear(spec, [5, 23], [23, 13], 0)
        self._run_sharded_linear(spec, [5, 15], [15, 14], 0)
Example 3
    def test_sharded_linear_rowwise(self):
        spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        # Test even split.
        self._run_sharded_linear(spec, [5, 16], [16, 11], 1)

        # Test uneven split.
        self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [5, 21], [21, 11], 1)

        # Test different ordering.
        spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )
        self._run_sharded_linear(spec, [5, 16], [16, 11], 1)
        self._run_sharded_linear(spec, [5, 19], [19, 11], 1)
        self._run_sharded_linear(spec, [5, 21], [21, 11], 1)
    def test_invalid_sharding(self):
        self.init_pg()

        spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"])
        pg = dist.new_group(ranks=[2, 3])
        if self.rank < 2:
            with self.assertRaisesRegex(ValueError, 'not part of process group'):
                _sharded_tensor.empty(spec, 10, 20, process_group=pg)

        spec = ChunkShardingSpec(dim='H', placements=["rank:1/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'needs to be an integer'):
            _sharded_tensor.empty(spec, 10, 20)

        for dim in [2, 3, 4, -3, -4, -5]:
            spec = ChunkShardingSpec(dim=dim, placements=["rank:1/cuda:1"])
            with self.assertRaisesRegex(ValueError, 'Invalid sharding dim'):
                _sharded_tensor.empty(spec, 10, 20)

        spec = ChunkShardingSpec(dim=0, placements=["rank:5/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'Invalid rank'):
            _sharded_tensor.empty(spec, 10, 20)

        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
        tensor = torch.empty(10, 20)
        with self.assertRaisesRegex(RuntimeError, "torch function 'add' not supported for ShardedTensor!"):
            torch.add(sharded_tensor, tensor)

        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'Only torch.strided layout is currently supported'):
            _sharded_tensor.empty(spec, 10, 20, layout=torch.sparse_coo)

        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'Only torch.contiguous_format memory_format is currently supported'):
            _sharded_tensor.empty(spec, 10, 20, memory_format=torch.channels_last)

        spec = ChunkShardingSpec(dim=0, placements=["worker0/cuda:1"])
        with self.assertRaisesRegex(RuntimeError, 'RPC framework needs to be initialized'):
            _sharded_tensor.empty(spec, 10, 20)

        spec = ChunkShardingSpec(dim=0, placements=["rank:0/cuda:1"])
        with self.assertRaisesRegex(RuntimeError, 'RPC was not initialized'):
            st = _sharded_tensor.empty(spec, 10, 20)
            st.remote_shards

        self.init_rpc()

        # ShardedTensor was initialized before RPC.
        with self.assertRaisesRegex(RuntimeError, 'RPC was not initialized'):
            st.remote_shards

        spec = ChunkShardingSpec(dim=0, placements=["workerfoo/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'Invalid worker name'):
            _sharded_tensor.empty(spec, 10, 20)
Example 5
    def test_sharded_tensor_metadata(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(torch.Size([10, 20]), sharded_tensor_metadata.size)
        self.assertEqual(torch.float, sharded_tensor_metadata.dtype)
        self.assertEqual(torch.strided, sharded_tensor_metadata.layout)
        self.assertEqual(False, sharded_tensor_metadata.requires_grad)
        self.assertEqual(torch.contiguous_format,
                         sharded_tensor_metadata.memory_format)
        self.assertEqual(False, sharded_tensor_metadata.pin_memory)

        sharded_tensor = _sharded_tensor.empty(spec,
                                               10,
                                               20,
                                               requires_grad=True)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(True, sharded_tensor_metadata.requires_grad)

        sharded_tensor = _sharded_tensor.empty(spec,
                                               10,
                                               20,
                                               dtype=torch.double)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(torch.double, sharded_tensor_metadata.dtype)

        # Need CPU for pin_memory
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cpu",
                "rank:1/cpu",
                "rank:2/cpu",
                "rank:3/cpu",
            ],
        )

        sharded_tensor = _sharded_tensor.empty(spec, 10, 20, pin_memory=True)
        sharded_tensor_metadata = sharded_tensor.metadata()
        self.assertEqual(True, sharded_tensor_metadata.pin_memory)
Example 6
    def test_init_sharded_tensor_with_kaiming_uniform(self):
        """ Test torch.nn.init.kaiming_uniform_(ShardedTensor, a, mode, nonlinearit) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        a, mode, nonlinearity = 0, 'fan_in', 'leaky_relu'

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.kaiming_uniform_(sharded_tensor, a=a, mode=mode, nonlinearity=nonlinearity)

        torch.manual_seed(seed)
        torch.nn.init.kaiming_uniform_(local_tensor_clone, a=a, mode=mode, nonlinearity=nonlinearity)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
def test_sharded_tensor_state_dict(single_process_pg):
    from torch.distributed._sharded_tensor import empty as sharded_tensor_empty
    from torch.distributed._sharding_spec import ChunkShardingSpec

    class BoringModelWithShardedTensor(BoringModel):
        def __init__(self, spec):
            super().__init__()
            self.sharded_tensor = sharded_tensor_empty(spec, 10, 20)
            self.sharded_tensor.local_shards()[0].tensor.fill_(0)

    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cpu",
        ],
    )

    m_0 = BoringModelWithShardedTensor(spec)
    m_0.sharded_tensor.local_shards()[0].tensor.fill_(1)
    name_st = ".sharded_tensor" if _TORCH_GREATER_EQUAL_1_11 else "sharded_tensor"
    assert name_st in m_0.state_dict(
    ), 'Expect "sharded_tensor" to appear in the state dict'

    m_1 = BoringModelWithShardedTensor(spec)
    assert not torch.allclose(
        m_1.sharded_tensor.local_shards()[0].tensor,
        m_0.sharded_tensor.local_shards()[0].tensor
    ), "Expect the shards to be different before `m_1` loading `m_0`'s state dict"

    m_1.load_state_dict(m_0.state_dict(), strict=False)
    assert torch.allclose(
        m_1.sharded_tensor.local_shards()[0].tensor,
        m_0.sharded_tensor.local_shards()[0].tensor
    ), "Expect the shards to be same after `m_1` loading `m_0`'s state dict"
Example 8
    def test_multiple_local_shards(self):
        self.init_pg()

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
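        # 16 rows across 8 placements means 8 chunks of 2 rows; since each rank appears
        # twice in the spec, every rank ends up owning two (2, 20) local shards.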
        sharded_tensor = _sharded_tensor.empty(spec, 16, 20)

        # Validate local shards.
        local_shards = sharded_tensor.local_shards()
        self.assertEqual(2, len(local_shards))
        for local_shard in local_shards:
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.tensor.device)
            self.assertEqual((2, 20), local_shard.tensor.size())

        # Validate global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(8, len(sharding_metadata))

        for shard_idx, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets)
            self.assertEqual([2, 20], shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{shard_idx % 4}/cuda:{shard_idx % 4}',
                             shard_metadata.placement)
Example 9
    def test_named_params_with_sharded_tensor(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)
        sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
        param_keys = list(sharded_model_params.keys())
        self.assertEqual(len(param_keys), 2)
        self.assertTrue("param" in param_keys)
        self.assertTrue("sharded_param" in param_keys)

        sharded_linear = MyShardedLinear(rank=self.rank).cuda(self.rank)
        sharded_linear.shard_parameter()
        sharded_linear_params = dict(named_params_with_sharded_tensor(sharded_linear))
        param_keys = list(sharded_linear_params.keys())
        self.assertEqual(len(param_keys), 4)
        self.assertTrue("linear1.bias" in param_keys)
        self.assertTrue("linear2.bias" in param_keys)
        self.assertTrue("linear1.weight" in param_keys)
        self.assertTrue("linear2.weight" in param_keys)
        self.assertFalse("bias" in param_keys)
Example 10
    def test_sharding_columns(self):
        self.init_pg()

        for dim in [1, -1]:
            spec = ChunkShardingSpec(
                dim=dim,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                ],
            )

            sharded_tensor = _sharded_tensor.empty(spec, 10, 32)

            # Validate local shard.
            local_shards = sharded_tensor.local_shards()
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.device)
            self.assertEqual((10, 8), local_shard.size())

            # Validate global metadata.
            sharding_metadata = sharded_tensor.sharding_metadata()
            self.assertEqual(4, len(sharding_metadata))

            for rank, shard_metadata in enumerate(sharding_metadata):
                self.assertEqual([0, rank * 8], shard_metadata.shard_offsets)
                self.assertEqual([10, 8], shard_metadata.shard_lengths)
                self.assertEqual(f'rank:{rank}/cuda:{rank}',
                                 shard_metadata.placement)
Example 11
    def test_insufficient_sharding_dims(self):
        self.init_pg()

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
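        # Only 2 rows are sharded across 4 placements, so just the first two chunks
        # exist: ranks 0 and 1 each hold a (1, 20) shard and ranks 2 and 3 hold none.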
        sharded_tensor = _sharded_tensor.empty(spec, 2, 20)

        # Validate local shard.
        local_shards = sharded_tensor.local_shards()
        if self.rank <= 1:
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.device)
            self.assertEqual((1, 20), local_shard.size())
        else:
            self.assertEqual(0, len(local_shards))

        # Validate global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(2, len(sharding_metadata))

        for shard_rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual([shard_rank, 0], shard_metadata.shard_offsets)
            self.assertEqual([1, 20], shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{shard_rank}/cuda:{shard_rank}',
                             shard_metadata.placement)
Example 12
    def test_sharded_tensor_sizes(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        # Test with *args
        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)
        self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

        # Test with single *args
        sharded_tensor = _sharded_tensor.empty(spec, 10)
        self.assertEqual(torch.Size([10]), sharded_tensor.size())

        # Test with list
        sharded_tensor = _sharded_tensor.empty(spec, [10, 20])
        self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

        # Test with tuple
        sharded_tensor = _sharded_tensor.empty(spec, (10, 20))
        self.assertEqual(torch.Size([10, 20]), sharded_tensor.size())

        with self.assertRaises(TypeError):
            sharded_tensor = _sharded_tensor.empty(spec, 'foo')
Example 13
    def test_partial_world_size(self):
        self.init_pg()

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
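        # The spec only places shards on ranks 2 and 3, so the (10, 20) tensor splits
        # into two (5, 20) shards on those ranks; ranks 0 and 1 hold no local shards.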
        sharded_tensor = _sharded_tensor.empty(spec, 10, 20)

        # Validate local shard.
        local_shards = sharded_tensor.local_shards()
        if self.rank >= 2:
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.device)
            self.assertEqual((5, 20), local_shard.size())
        else:
            self.assertEqual(0, len(local_shards))

        # Validate global metadata.
        sharding_metadata = sharded_tensor.sharding_metadata()
        self.assertEqual(2, len(sharding_metadata))

        for shard_rank, shard_metadata in enumerate(sharding_metadata):
            self.assertEqual([shard_rank * 5, 0], shard_metadata.shard_offsets)
            self.assertEqual([5, 20], shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{shard_rank + 2}/cuda:{shard_rank + 2}',
                             shard_metadata.placement)
Example 14
    def test_init_sharded_tensor_with_normal(self):
        """ Test torch.nn.init.normal_(ShardedTensor, mean, std) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        h, w = 8, 2
        expected_h = 2
        expected_device = torch.device(f"cuda:{self.rank}")
        mean, std = 10, 5

        seed = 1234
        dtype = torch.double

        sharded_tensor = _sharded_tensor.empty(spec, h, w, dtype=dtype)
        self.assertEqual(1, len(sharded_tensor.local_shards()))

        # Clone local tensor to ensure torch.nn.init starts from the same input
        local_tensor_clone = torch.clone(sharded_tensor.local_shards()[0].tensor)
        torch.manual_seed(seed)
        torch.nn.init.normal_(sharded_tensor, mean=mean, std=std)

        torch.manual_seed(seed)
        torch.nn.init.normal_(local_tensor_clone, mean=mean, std=std)
        self.assertEqual(local_tensor_clone, sharded_tensor.local_shards()[0].tensor)
Example 15
    def test_new_group(self):

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:2",
                "rank:2/cuda:3",
            ],
        )

        pg = dist.new_group(ranks=[1, 2, 3])
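        # The 'rank:N' entries in the spec are relative to this subgroup, so rank:1 and
        # rank:2 map to global ranks 2 and 3 (hence the cuda:2 and cuda:3 devices).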
        if self.rank >= 1:
            sharded_tensor = _sharded_tensor.empty(spec,
                                                   10,
                                                   20,
                                                   process_group=pg)

            # Validate local shard.
            local_shards = sharded_tensor.local_shards()
            if self.rank >= 2:
                self.assertEqual(1, len(local_shards))
                local_shard = local_shards[0].tensor
                self.assertEqual(torch.device(f"cuda:{self.rank}"),
                                 local_shard.device)
                self.assertEqual((5, 20), local_shard.size())
            else:
                self.assertEqual(0, len(local_shards))

            # Validate global metadata.
            sharded_tensor_metadata = sharded_tensor.metadata()
            shards_metadata = sharded_tensor_metadata.shards_metadata
            self.assertEqual(2, len(shards_metadata))

            for shard_rank, shard_metadata in enumerate(shards_metadata):
                self.assertEqual([shard_rank * 5, 0],
                                 shard_metadata.shard_offsets)
                self.assertEqual([5, 20], shard_metadata.shard_lengths)
                self.assertEqual(
                    f'rank:{shard_rank + 1}/cuda:{shard_rank + 2}',
                    shard_metadata.placement)

            # Validate remote shards.
            remote_shards = sharded_tensor.remote_shards
            if self.rank >= 2:
                self.assertEqual(1, len(remote_shards))
            else:
                self.assertEqual(2, len(remote_shards))

            for rpc_rank, shards in remote_shards.items():
                self.assertEqual(1, len(shards))
                for remote_shard in shards:
                    shard = remote_shard.to_here()
                    self.assertEqual(rpc_rank, remote_shard.owner().id)
                    self.assertEqual(f'rank:{rpc_rank - 1}/cuda:{rpc_rank}',
                                     shard.metadata.placement)
                    self.assertEqual((5, 20), shard.tensor.size())
Example 16
    def test_sharded_optim(self):
        rowwise_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        local_model = MyShardedModel().cuda(self.rank)
        sharded_model = MyShardedModel(spec=rowwise_spec).cuda(self.rank)

        # copy the parameters from the local model
        sharded_model.sharded_param.local_shards()[0].tensor = \
            local_model.sharded_param.detach().clone().requires_grad_()

        local_optim = optim.SGD(local_model.parameters(), lr=0.1)
        sharded_model_params = dict(named_params_with_sharded_tensor(sharded_model))
        sharded_optim = ShardedOptimizer(sharded_model_params, optim.SGD, lr=0.1)

        local_optim.zero_grad()
        sharded_optim.zero_grad()

        before_update = deepcopy(sharded_optim.named_params)

        inp = torch.rand([5, 10]).cuda(self.rank).requires_grad_()

        # run forward
        local_output = local_model(inp)
        sharded_output = sharded_model(inp)
        # backward
        local_output.sum().backward()
        sharded_output.sum().backward()

        # optimizer update
        local_optim.step()
        sharded_optim.step()

        # make sure the parameters (including sharded param)
        # get updated by the optimizer, and the updated
        # local params are the same as the sharded params
        for key, val in before_update.items():
            new_val = sharded_optim.named_params[key]
            if isinstance(val, sharded_tensor.ShardedTensor):
                self.assertNotEqual(
                    val.local_shards()[0].tensor,
                    new_val.local_shards()[0].tensor
                )
                self.assertEqual(
                    new_val.local_shards()[0].tensor,
                    local_model.sharded_param
                )
            else:
                self.assertNotEqual(val, new_val)
                self.assertEqual(new_val, local_model.param)
Example 17
    def test_complete_world_size(self):

        for dim in [0, -2]:
            spec = ChunkShardingSpec(
                dim=dim,
                placements=[
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                ],
            )
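            # Splitting 10 rows across 4 placements gives chunks of 3 rows, so ranks 0-2
            # each hold a (3, 20) shard and rank 3 holds the (1, 20) remainder.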
            sharded_tensor = _sharded_tensor.empty(spec, 10, 20)

            # Validate local shard.
            local_shards = sharded_tensor.local_shards()
            self.assertEqual(1, len(local_shards))
            local_shard = local_shards[0].tensor
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.device)
            if self.rank == 3:
                self.assertEqual((1, 20), local_shard.size())
            else:
                self.assertEqual((3, 20), local_shard.size())

            # Validate global metadata.
            sharded_tensor_metadata = sharded_tensor.metadata()
            shards_metadata = sharded_tensor_metadata.shards_metadata
            self.assertEqual(4, len(shards_metadata))

            for rank, shard_metadata in enumerate(shards_metadata):
                self.assertEqual([rank * 3, 0], shard_metadata.shard_offsets)
                if rank == 3:
                    self.assertEqual([1, 20], shard_metadata.shard_lengths)
                else:
                    self.assertEqual([3, 20], shard_metadata.shard_lengths)
                self.assertEqual(f'rank:{rank}/cuda:{rank}',
                                 shard_metadata.placement)

            # Validate remote shards.
            remote_shards = sharded_tensor.remote_shards
            self.assertEqual(3, len(remote_shards))

            for rpc_rank, shards in remote_shards.items():
                self.assertEqual(1, len(shards))
                for remote_shard in shards:
                    self.assertEqual(rpc_rank, remote_shard.owner().id)
                    shard = remote_shard.to_here()
                    self.assertEqual(f'rank:{rpc_rank}/cuda:{rpc_rank}',
                                     shard.metadata.placement)
                    if rpc_rank == 3:
                        self.assertEqual((1, 20), shard.tensor.size())
                    else:
                        self.assertEqual((3, 20), shard.tensor.size())
Example 18
    def shard_parameter(self):
        rowwise_sharding_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        colwise_sharding_spec = ChunkShardingSpec(
            dim=1,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        sharded_tensor.shard_parameter(self.linear1, "weight", rowwise_sharding_spec)
        sharded_tensor.shard_parameter(self.linear2, "weight", colwise_sharding_spec)
    def test_invalid_pg_rpc_ranks(self):
        self.init_pg()

        # Init RPC with different ranks.
        rpc_backend_options = rpc.TensorPipeRpcBackendOptions()
        rpc_backend_options.init_method = f"file://{self.file_name}"
        rank = (self.rank + 1) % self.world_size
        rpc.init_rpc(
            name=f'worker{rank}',
            rank=rank,
            world_size=self.world_size,
            rpc_backend_options=rpc_backend_options,
        )

        spec = ChunkShardingSpec(dim=0, placements=["rank:1/cuda:1"])
        with self.assertRaisesRegex(ValueError, 'Default ProcessGroup and RPC ranks must be the same'):
            _sharded_tensor.empty(spec, 10, 20)
Example 20
    def test_multiple_local_shards(self):
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )
        sharded_tensor = _sharded_tensor.empty(spec, 16, 20)

        # Validate local shards.
        local_shards = sharded_tensor.local_shards()
        self.assertEqual(2, len(local_shards))
        for local_shard in local_shards:
            self.assertEqual(torch.device(f"cuda:{self.rank}"),
                             local_shard.tensor.device)
            self.assertEqual((2, 20), local_shard.tensor.size())

        # Validate global metadata.
        sharded_tensor_metadata = sharded_tensor.metadata()
        shards_metadata = sharded_tensor_metadata.shards_metadata
        self.assertEqual(8, len(shards_metadata))

        for shard_idx, shard_metadata in enumerate(shards_metadata):
            self.assertEqual([shard_idx * 2, 0], shard_metadata.shard_offsets)
            self.assertEqual([2, 20], shard_metadata.shard_lengths)
            self.assertEqual(f'rank:{shard_idx % 4}/cuda:{shard_idx % 4}',
                             shard_metadata.placement)

        # Validate remote shards.
        remote_shards = sharded_tensor.remote_shards
        self.assertEqual(3, len(remote_shards))
        owners = {}
        for rpc_rank, shards in remote_shards.items():
            self.assertEqual(2, len(shards))
            for remote_shard in shards:
                shard = remote_shard.to_here()
                self.assertEqual((2, 20), shard.tensor.size())
                self.assertEqual(rpc_rank, remote_shard.owner().id)
Example 21
    def test_chunked_sharding_spec(self):
        # Test valid specs.
        ChunkShardingSpec(0, [0, 1])
        # Named dimension.
        ChunkShardingSpec("N", ["cuda:0", "cuda:1"])
        ChunkShardingSpec(0, [torch.device("cuda:0"), torch.device("cuda:1")])
        ChunkShardingSpec(-1, ["cuda:0", "cuda:1"])
        ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:0/cuda:1"])
        ChunkShardingSpec(0, ["rank:0", "rank:1"])
        ChunkShardingSpec(0, ["rank:0/cpu", "rank:1/cpu"])

        # Test invalid specs
        with self.assertRaisesRegex(ValueError, "int or str"):
            ChunkShardingSpec(None, ["cuda:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "int or str"):
            ChunkShardingSpec({}, ["cuda:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["random:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["cuda:foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:0/foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:0/random:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "not a valid device"):
            ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"])
Example 22
    def test_torch_equal(self):
        """ Test torch.equal(ShardedTensor, ShardedTensor) """

        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
                "rank:2/cuda:2",
                "rank:3/cuda:3",
            ],
        )

        alt_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
                "rank:0/cuda:0",
                "rank:3/cuda:3",
                "rank:2/cuda:2",
            ],
        )

        st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
        self.assertTrue(torch.equal(st1, st2))

        st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
        if self.rank == 0:
            torch.nn.init.uniform_(st1.local_shards()[0].tensor)
        self.assertFalse(torch.equal(st1, st2))

        st1 = _sharded_tensor.ones(spec, 10, 10)
        st2 = _sharded_tensor.ones(spec, 10, 5)
        self.assertFalse(torch.equal(st1, st2))

        st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
        self.assertFalse(torch.equal(st1, st2))

        st1 = _sharded_tensor.ones(spec, 10, 10)
        st2 = _sharded_tensor.zeros(spec, 10, 10)
        self.assertFalse(torch.equal(st1, st2))

        st1 = _sharded_tensor.ones(spec, 10, 10)
        st2 = _sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
        self.assertFalse(torch.equal(st1, st2))

        st1 = _sharded_tensor.ones(spec, 10, 10)
        st2 = _sharded_tensor.ones(spec, 10, 10, requires_grad=True)
        self.assertFalse(torch.equal(st1, st2))

        cpu_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cpu",
                "rank:1/cpu",
                "rank:2/cpu",
                "rank:3/cpu",
            ],
        )
        st1 = _sharded_tensor.ones(cpu_spec, 10, 10)
        st2 = _sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
        self.assertFalse(torch.equal(st1, st2))

        pg = dist.new_group([1, 0, 3, 2])
        st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
        self.assertFalse(torch.equal(st1, st2))

        pg = dist.new_group([0, 1, 2, 3])
        st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
        self.assertFalse(torch.equal(st1, st2))
Example 23
    def test_sharded_linear_errors(self):
        for spec in generate_chunk_sharding_specs_for_test(0):
            fc1 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc1, "bias", spec)
            with self.assertRaisesRegex(
                    TypeError, 'input and bias need to be torch.Tensor'):
                fc1(torch.rand(10, 10).cuda(self.rank))

            fc2 = torch.nn.Linear(10, 10).cuda(self.rank)
            shard_parameter(fc2, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Input needs to have at least 1 dim'):
                fc2(torch.tensor(1).cuda(self.rank))

            fc3 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc3.weight = torch.nn.Parameter(
                torch.rand(10, 10, 10).cuda(self.rank))
            shard_parameter(fc3, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Weight needs to have exactly 2 dims'):
                fc3(torch.rand(10, 10).cuda(self.rank))

            fc4 = torch.nn.Linear(10, 10).cuda(self.rank)
            fc4.bias = torch.nn.Parameter(torch.rand(10, 10).cuda(self.rank))
            shard_parameter(fc4, "weight", spec)
            with self.assertRaisesRegex(ValueError,
                                        'Bias needs to have exactly 1 dim'):
                fc4(torch.rand(10, 10).cuda(self.rank))

            fc5 = torch.nn.Linear(7, 10).cuda(self.rank)
            shard_parameter(fc5, "weight", spec)
            with self.assertRaisesRegex(
                    ValueError,
                    'Input dim: 13 does not match appropriate weight dim: 7'):
                fc5(torch.rand(20, 10, 13).cuda(self.rank))

            fc6 = torch.nn.Linear(10, 10).cuda(self.rank)
            del fc6.weight
            enumerable_spec = EnumerableShardingSpec([
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[5, 5],
                    placement="rank:0/cuda:0",
                ),
                ShardMetadata(
                    shard_offsets=[0, 5],
                    shard_sizes=[5, 5],
                    placement="rank:1/cuda:1",
                ),
                ShardMetadata(
                    shard_offsets=[5, 0],
                    shard_sizes=[5, 5],
                    placement="rank:2/cuda:2",
                ),
                ShardMetadata(
                    shard_offsets=[5, 5],
                    shard_sizes=[5, 5],
                    placement="rank:3/cuda:3",
                )
            ])

            fc6.weight = empty(enumerable_spec, 10, 10)
            with self.assertRaisesRegex(
                    ValueError,
                    'Only ChunkShardingSpec supported for ShardedTensor ops!'):
                fc6(torch.rand(10, 10).cuda(self.rank))

            fc7 = torch.nn.Linear(10, 80).cuda(self.rank)
            multiple_local_shard_spec = ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0/cuda:0",
                    "rank:0/cuda:0",
                    "rank:1/cuda:1",
                    "rank:1/cuda:1",
                    "rank:2/cuda:2",
                    "rank:2/cuda:2",
                    "rank:3/cuda:3",
                    "rank:3/cuda:3",
                ],
            )
            del fc7.weight
            fc7.weight = empty(multiple_local_shard_spec, 80, 10)
            with self.assertRaisesRegex(ValueError,
                                        'Only one local shard supported!'):
                fc7(torch.rand(10, 10).cuda(self.rank))
Example 24
    def test_chunked_sharding_spec(self):
        # Test valid specs.
        ChunkShardingSpec(0, [torch.device(0), torch.device(1)])
        # Named dimension.
        ChunkShardingSpec("N", ["cuda:0", "cuda:1"])
        ChunkShardingSpec(0, [torch.device("cuda:0"), torch.device("cuda:1")])
        ChunkShardingSpec(-1, ["cuda:0", "cuda:1"])
        ChunkShardingSpec(0, ["rank:0/cuda:0", "rank:0/cuda:1"])
        ChunkShardingSpec(0, ["rank:0", "rank:1"])
        ChunkShardingSpec(0, ["rank:0/cpu", "rank:1/cpu"])

        # Test invalid specs
        with self.assertRaisesRegex(ValueError, "int or str"):
            ChunkShardingSpec(None, ["cuda:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError, "int or str"):
            ChunkShardingSpec({}, ["cuda:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError,
                                    "Could not parse remote_device"):
            ChunkShardingSpec(0, ["random:0", "cuda:1"])
        with self.assertRaisesRegex(ValueError,
                                    "Could not parse remote_device"):
            ChunkShardingSpec(0, ["cuda:foo", "cuda:1"])
        with self.assertRaisesRegex(ValueError,
                                    "Could not parse remote_device"):
            ChunkShardingSpec(0, ["rank:foo", "cuda:1"])
        with self.assertRaisesRegex(RuntimeError, "Expected one of"):
            ChunkShardingSpec(0, ["rank:0/foo", "cuda:1"])
        with self.assertRaisesRegex(RuntimeError, "Expected one of"):
            ChunkShardingSpec(0, ["rank:0/random:0", "cuda:1"])
        with self.assertRaisesRegex(RuntimeError, "Invalid device string"):
            ChunkShardingSpec(0, ["rank:0/cuda:foo", "cuda:1"])