    def test_validate_metadata(self) -> None:
        module = TestModule()

        metadata, _, _ = _prepare(module.state_dict())
        self.assertTrue(
            "regular" in metadata.state_dict_metadata,
            f"keys: {metadata.state_dict_metadata.keys()}",
        )

        module = TestModule()
        validate_metadata(module.state_dict(), metadata)

        module = TestModule()
        module.extra_param = torch.nn.Parameter(torch.zeros(2, 2))
        with self.assertRaisesRegex(ValueError, "Could not find Tensor metadata"):
            validate_metadata(module.state_dict(), metadata)

        module = TestModule()
        module.regular = torch.nn.Parameter(torch.zeros(2, 4))

        with self.assertRaisesRegex(ValueError, "Incompatible tensor size"):
            validate_metadata(module.state_dict(), metadata)

        module = TestModule()
        module.extra_sharded = sharded_tensor.zeros(module.spec(), 4, 2)
        with self.assertRaisesRegex(ValueError, "Could not find ShardedTensor metadata"):
            validate_metadata(module.state_dict(), metadata)

        module = TestModule()
        module.sharded = sharded_tensor.zeros(module.spec(), 4, 2)
        with self.assertRaisesRegex(ValueError, "Incompatible ShardedTensor size"):
            validate_metadata(module.state_dict(), metadata)
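These tests rely on a TestModule fixture defined elsewhere in the test file. The code below is a minimal sketch of what such a module could look like, with shapes inferred from the assertions above; it is an illustration, not the actual fixture (the real one would also need a state_dict hook so the ShardedTensor appears in state_dict()).

import torch
from torch.distributed._shard import sharded_tensor
from torch.distributed._shard.sharding_spec import ChunkShardingSpec

class TestModule(torch.nn.Module):
    # Hypothetical sketch: shapes are inferred, not taken from the real test.
    def __init__(self) -> None:
        super().__init__()
        # (4, 4) so the (2, 4) override above trips "Incompatible tensor size"
        self.regular = torch.nn.Parameter(torch.ones(4, 4))
        # (4, 4) so the (4, 2) override above trips "Incompatible ShardedTensor size"
        self.sharded = sharded_tensor.zeros(self.spec(), 4, 4)

    def spec(self) -> ChunkShardingSpec:
        # shard dim 0 across two GPU ranks
        return ChunkShardingSpec(
            dim=0,
            placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        )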
Example #2
def _load():
    metadata, _, _ = _prepare(state_dict, write_replicated)
    load_state_dict(
        state_dict,
        storage_reader=FaultyStorageReader(metadata, kwargs),
        coordinator_rank=coordinator,
        no_dist=no_dist,
    )
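FaultyStorageReader is likewise defined in the test file. The sketch below shows one way such a fault-injecting reader could be built; the read_metadata/read_bytes/read_tensors surface and the failure-key convention are assumptions based on how _load uses it, not the real implementation.

import torch.distributed as dist
from torch.futures import Future

class FaultyStorageReader:
    # Sketch only: hands back the prepared metadata and raises on ranks
    # listed in fail_conf, e.g. {"fail_read_metadata": [0]}.
    def __init__(self, metadata, fail_conf):
        self.metadata = metadata
        self.fail_conf = fail_conf

    def _maybe_fail(self, key):
        rank = dist.get_rank() if dist.is_initialized() else 0
        if rank in self.fail_conf.get(key, []):
            raise ValueError(f"injected failure: {key}")

    def read_metadata(self):
        self._maybe_fail("fail_read_metadata")
        return self.metadata

    def read_bytes(self, requests):
        self._maybe_fail("fail_read_bytes")
        fut = Future()
        fut.set_result(None)
        return fut

    def read_tensors(self, requests):
        self._maybe_fail("fail_read_tensors")
        fut = Future()
        fut.set_result(None)
        return fut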
Example #3
    def gen_metadata(self) -> Metadata:
        module = TestModule()
        # compute the default saved metadata (must pass write_replicated_data=True
        # or the resulting metadata will be incomplete)
        metadata, _, _ = _prepare(module.state_dict(), True)

        # _prepare only produces the full metadata on the coordinator rank,
        # so broadcast it so that every rank returns the same Metadata
        metadata = [metadata]
        dist.broadcast_object_list(metadata)

        return metadata[0]
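The list wrapping above is just the calling convention of dist.broadcast_object_list: it overwrites each rank's slot in place with the source rank's object, so every rank returns identical metadata. A standalone illustration (assumes an initialized process group):

import torch.distributed as dist

payload = {"rank": dist.get_rank()}     # differs per rank before the broadcast
buf = [payload]
dist.broadcast_object_list(buf, src=0)  # in place: buf[0] becomes rank 0's object
payload = buf[0]                        # now identical on every rank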
Example #4
    def test_storage_key_mapping(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            'sharded': sharded_tensor.rand(spec, (10, 10)),
            'replicated': torch.rand(4, device=device),
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(
            state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertEqual(bytes_reqs[0].storage_key,
                             metadata.state_dict_metadata['bytes'].storage_key)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            self.assertEqual(
                replicated.storage_key,
                metadata.state_dict_metadata['replicated'].storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
            shard_keys = [
                sm.storage_key for sm in
                metadata.state_dict_metadata['sharded'].storage_metadata
            ]
            self.assertTrue(shard.storage_key in shard_keys)
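_prepare splits the state_dict into byte write requests (for non-tensor values such as the 'bytes' list) and tensor write requests, each tagged with the storage_key under which a writer should persist it. The following is a hypothetical consumer of that output, assuming each bytes request carries an io.BytesIO-like buffer; it is not the real StorageWriter API.

import os
import torch

def write_requests(root, bytes_reqs, tensor_reqs):
    # Hypothetical sketch: persist every request under its storage key.
    os.makedirs(root, exist_ok=True)
    for br in bytes_reqs:
        with open(os.path.join(root, br.storage_key), "wb") as f:
            f.write(br.bytes.getbuffer())  # assumes an io.BytesIO-like buffer
    for tr in tensor_reqs:
        torch.save(tr.tensor, os.path.join(root, tr.storage_key))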
Example #5
    def test_storage_key_mapping(self) -> None:
        device = f"cuda:{dist.get_rank()}"
        spec = ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )

        state_dict = {
            'sharded': sharded_tensor.rand(spec, (10, 10)),
            'replicated': torch.rand(4, device=device),
            'bytes': [1, 2, 3, 4],
        }

        metadata, bytes_reqs, tensor_reqs = _prepare(state_dict, write_replicated_data=self.rank == 0)

        if self.rank == 0:
            self.assertEqual(1, len(bytes_reqs))
            self.assertEqual(2, len(tensor_reqs))

            self.assertTrue('bytes' in metadata.state_dict_metadata)
            self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

            # tensor ordering is unspecified
            if len(tensor_reqs[0].tensor.size()) == 1:
                replicated = tensor_reqs[0]
                shard = tensor_reqs[1]
            else:
                replicated = tensor_reqs[1]
                shard = tensor_reqs[0]

            self.assertTrue('replicated' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('replicated', torch.Size([0]))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertEqual(metadata.storage_data[storage_key], replicated.storage_key)
        else:
            self.assertEqual(0, len(bytes_reqs))
            self.assertEqual(1, len(tensor_reqs))
            shard = tensor_reqs[0]
            local_shard = state_dict["sharded"].local_shards()[0]

            self.assertTrue('sharded' in metadata.state_dict_metadata)
            storage_key = MetadataIndex('sharded', torch.Size(local_shard.metadata.shard_offsets))
            self.assertTrue(storage_key in metadata.storage_data)
            self.assertEqual(metadata.storage_data[storage_key], shard.storage_key)
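Unlike Example #4, which reads storage keys directly off the tensor metadata, this variant resolves them through metadata.storage_data keyed by MetadataIndex(fqn, offset). A minimal sketch of building such an index (the import path is an assumption and has moved between PyTorch releases):

import torch
from torch.distributed._shard.checkpoint.metadata import MetadataIndex  # assumed path

# a fully qualified name plus the shard's global offset identifies one stored blob
idx = MetadataIndex('sharded', torch.Size([5, 0]))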