    def test_validate_metadata(self) -> None:
        module = TestModule()

        # _prepare computes the checkpoint metadata for this state_dict
        metadata, _, _ = _prepare(module.state_dict())
        self.assertIn(
            "regular",
            metadata.state_dict_metadata,
            f"keys: {metadata.state_dict_metadata.keys()}",
        )

        # a fresh module with identical shapes validates cleanly
        module = TestModule()
        validate_metadata(module.state_dict(), metadata)

        # an extra tensor key that was never saved cannot be validated
        module = TestModule()
        module.extra_param = torch.nn.Parameter(torch.zeros(2, 2))
        with self.assertRaisesRegex(ValueError, "Could not find Tensor metadata"):
            validate_metadata(module.state_dict(), metadata)

        # a regular tensor whose shape changed since the save is rejected
        module = TestModule()
        module.regular = torch.nn.Parameter(torch.zeros(2, 4))
        with self.assertRaisesRegex(ValueError, "Incompatible tensor size"):
            validate_metadata(module.state_dict(), metadata)

        # an extra ShardedTensor key that was never saved cannot be validated
        module = TestModule()
        module.extra_sharded = sharded_tensor.zeros(module.spec(), 4, 2)
        with self.assertRaisesRegex(ValueError, "Could not find ShardedTensor metadata"):
            validate_metadata(module.state_dict(), metadata)

        # a ShardedTensor whose global shape changed is rejected
        module = TestModule()
        module.sharded = sharded_tensor.zeros(module.spec(), 4, 2)
        with self.assertRaisesRegex(ValueError, "Incompatible ShardedTensor size"):
            validate_metadata(module.state_dict(), metadata)

    def test_checkpoint_has_storage_type_mismatch(self) -> None:
        module = TestModule()

        # plain-tensor metadata stored under the sharded key: validation
        # expects ShardedTensorStorageMetadata there and must fail
        metadata = self.gen_metadata()
        regular = metadata.state_dict_metadata["regular"]
        metadata.state_dict_metadata[".sharded"] = regular
        with self.assertRaisesRegex(ValueError, "ShardedTensorStorageMetadata but found"):
            validate_metadata(module.state_dict(), metadata)

        # the reverse mismatch: ShardedTensor metadata under the regular key
        metadata = self.gen_metadata()
        sharded = metadata.state_dict_metadata[".sharded"]
        metadata.state_dict_metadata["regular"] = sharded
        with self.assertRaisesRegex(ValueError, "TensorStorageMetadata but found"):
            validate_metadata(module.state_dict(), metadata)

    def test_checkpoint_has_shard_overlap(self) -> None:
        metadata = self.gen_metadata()

        self.assertIn(
            "sharded",
            metadata.state_dict_metadata,
            f"keys: {metadata.state_dict_metadata.keys()}",
        )

        # grow every dimension of the first stored shard so that it
        # overlaps the neighboring shard
        sizes = (
            metadata.state_dict_metadata["sharded"]
            .storage_metadata[0]
            .shard_metadata.shard_sizes
        )
        for i in range(len(sizes)):
            sizes[i] += 1

        module = TestModule()
        with self.assertRaisesRegex(ValueError, "overlap"):
            validate_metadata(module.state_dict(), metadata)
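
These test methods rely on scaffolding the snippets do not show: a TestModule holding one regular parameter and one ShardedTensor, a gen_metadata helper, and the _prepare/validate_metadata imports. The sketch below reconstructs that scaffolding from the calls above; the import paths (the older torch.distributed._shard.checkpoint API), the 4x4 shapes, the host class name, and the two-rank placement are assumptions, and actually running the tests also needs an initialized multi-rank process group (the upstream suite uses a multi-GPU fixture).

import unittest
from typing import Optional

import torch
import torch.distributed._shard.sharded_tensor as sharded_tensor
from torch.distributed._shard.checkpoint.state_dict_loader import validate_metadata
from torch.distributed._shard.checkpoint.state_dict_saver import _prepare
from torch.distributed._shard.sharding_spec import ChunkShardingSpec


class TestModule(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        # one ShardedTensor and one plain Parameter, so the checkpoint
        # metadata contains both storage types the tests exercise
        self.sharded = sharded_tensor.zeros(self.spec(), 4, 4)  # assumed shape
        self.regular = torch.nn.Parameter(torch.ones(4, 4))  # assumed shape
        self.extra_sharded: Optional[sharded_tensor.ShardedTensor] = None
        self.extra_param: Optional[torch.nn.Parameter] = None

    def spec(self) -> ChunkShardingSpec:
        # shard dim 0 across two ranks; assumes a two-GPU process group
        return ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0/cuda:0",
                "rank:1/cuda:1",
            ],
        )


class TestDistributedCheckpointing(unittest.TestCase):  # hypothetical host class
    def gen_metadata(self):
        # metadata as a save of a freshly constructed module would record it
        metadata, _, _ = _prepare(TestModule().state_dict())
        return metadata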