Code example #1
    def test_validate_metadata(self) -> None:
        module = TestModule()

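        # Build metadata from the module's state_dict and confirm the "regular" key was recorded.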
        metadata, _, _ = _prepare(module.state_dict())
        self.assertTrue(
            "regular" in metadata.state_dict_metadata,
            f"keys: {metadata.state_dict_metadata.keys()}",
        )

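        # A state_dict from an identical module validates cleanly.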
        module = TestModule()
        validate_metadata(module.state_dict(), metadata)

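        # An extra dense parameter has no matching Tensor metadata.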
        module = TestModule()
        module.extra_param = torch.nn.Parameter(torch.zeros(2, 2))
        with self.assertRaisesRegex(ValueError, "Could not find Tensor metadata"):
            validate_metadata(module.state_dict(), metadata)

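        # A parameter whose shape differs from the recorded one is rejected.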
        module = TestModule()
        module.regular = torch.nn.Parameter(torch.zeros(2, 4))

        with self.assertRaisesRegex(ValueError, "Incompatible tensor size"):
            validate_metadata(module.state_dict(), metadata)

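        # An unexpected ShardedTensor key has no matching ShardedTensor metadata.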
        module = TestModule()
        module.extra_sharded = sharded_tensor.zeros(module.spec(), 4, 2)
        with self.assertRaisesRegex(ValueError, "Could not find ShardedTensor metadata"):
            validate_metadata(module.state_dict(), metadata)

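        # A ShardedTensor with a different global shape is rejected.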
        module = TestModule()
        module.sharded = sharded_tensor.zeros(module.spec(), 4, 2)
        with self.assertRaisesRegex(ValueError, "Incompatible ShardedTensor size"):
            validate_metadata(module.state_dict(), metadata)
Code example #2
    # Baseline module for the metadata tests: one ShardedTensor, one dense
    # parameter, and two optional slots that individual test cases fill in.
    def __init__(self) -> None:
        super().__init__()
        self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
        self.regular = torch.nn.Parameter(torch.ones(4, 4))
        self.extra_sharded: Optional[ShardedTensor] = None
        self.extra_param: Optional[torch.nn.Parameter] = None
        self._register_state_dict_hook(state_dict_hook)
Code example #3
    def _test_common_failures(self, cmp_op):
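        # Each pair below differs in a single property (values, shape, sharding spec,
        # dtype, requires_grad, pin_memory, or process group), so cmp_op must report
        # a mismatch or raise.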
        spec, alt_spec = self.get_gpu_specs()

        st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
        if self.rank == 0:
            torch.nn.init.uniform_(st1.local_shards()[0].tensor)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.ones(spec, 10, 5)
        self.assertFalse(cmp_op(st1, st2))

        st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.zeros(spec, 10, 10)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
        self.assertFalse(cmp_op(st1, st2))

        st1 = sharded_tensor.ones(spec, 10, 10)
        st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True)
        self.assertFalse(cmp_op(st1, st2))

        cpu_spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cpu",
                "rank:1/cpu",
                "rank:2/cpu",
                "rank:3/cpu",
            ],
        )
        st1 = sharded_tensor.ones(cpu_spec, 10, 10)
        st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
        self.assertFalse(cmp_op(st1, st2))

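        # Tensors built on different process groups are an error, not just an inequality.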
        pg = dist.new_group([1, 0, 3, 2])
        st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
        with self.assertRaisesRegex(
                RuntimeError,
                "All distributed tensors should use the same ProcessGroup"):
            cmp_op(st1, st2)

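        # A new group over the same ranks is still a distinct ProcessGroup, so the same error applies.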
        pg = dist.new_group([0, 1, 2, 3])
        st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
        with self.assertRaisesRegex(
                RuntimeError,
                "All distributed tensors should use the same ProcessGroup"):
            cmp_op(st1, st2)
Code example #4
    def test_tensor_metadata_with_missing_rank_spec(self) -> None:
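        # The spec places the whole tensor on rank 1; preparing the write should still
        # produce exactly one storage-metadata entry and one entry in the write mapping.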
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:1/cuda:1",
            ],
        )

        st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
        mapping = dict()

        (_, md, storage_md) = _prepare_sharded_tensor_write("fqn", st, "tensor", mapping)

        self.assertEqual(1, len(storage_md))
        self.assertEqual(1, len(mapping))
Code example #5
    def test_switch_between_sharded_tensor_to_tensor(self) -> None:
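        # Save under one sharding spec and reload under another, loading what was saved
        # as a ShardedTensor into a plain tensor and vice versa, for every spec pairing.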
        path = self.get_file_path()
        tensor_size = 32

        specs = [
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                ],
            ),
            ChunkShardingSpec(
                dim=0,
                placements=[
                    "rank:0",
                    "rank:1",
                    "rank:1",
                    "rank:0",
                ],
            ),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[8],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8],
                    shard_sizes=[tensor_size - 8],
                    placement="rank:0",
                ),
            ]),
            EnumerableShardingSpec(shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[10],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[10],
                    shard_sizes=[tensor_size - 10],
                    placement="rank:1",
                ),
            ]),
        ]

        for save_spec in specs:
            for load_spec in specs:
                save_dict = {
                    'sharded':
                    sharded_tensor.rand(save_spec, tensor_size),
                    'replicated':
                    torch.rand(tensor_size, device=f"cpu:{self.rank}")
                }

                fs_writer = FileSystemWriter(path=path)
                save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

                # "Freaky Friday" the tensors: the sharded entry is loaded into a plain
                # tensor and the replicated entry into a ShardedTensor
                load_dict = {
                    'sharded': torch.zeros(tensor_size,
                                           device=f"cpu:{self.rank}"),
                    'replicated': sharded_tensor.zeros(load_spec, tensor_size)
                }

                fs_reader = FileSystemReader(path=path)
                load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

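                # load_tensor (test helper) is assumed to gather a ShardedTensor into a
                # full tensor; the comparison below runs only on rank 0.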
                save_dict_sharded = self.load_tensor(save_dict['sharded'])
                load_dict_replicated = self.load_tensor(
                    load_dict['replicated'])

                if dist.get_rank() == 0:
                    self.assertTrue(
                        torch.allclose(save_dict_sharded,
                                       load_dict['sharded']),
                        f"save-spec {save_spec} load-spec {load_spec}")
                    self.assertTrue(
                        torch.allclose(save_dict['replicated'],
                                       load_dict_replicated),
                        f"save-spec {save_spec} load-spec {load_spec}")