def test_validate_metadata(self) -> None:
    """Validate a module's state_dict against metadata, then against mutations.

    Metadata is produced by ``_prepare`` from a fresh ``TestModule`` and must
    contain a ``"regular"`` entry. A second, unmodified ``TestModule`` has to
    validate cleanly. Each entry of ``mismatch_cases`` then sets one attribute
    on a fresh ``TestModule`` (an extra or re-shaped Parameter / ShardedTensor)
    and expects ``validate_metadata`` to raise ``ValueError`` whose message
    matches the listed pattern.
    """
    module = TestModule()
    metadata, _, _ = _prepare(module.state_dict())
    self.assertTrue(
        "regular" in metadata.state_dict_metadata,
        f"keys: {metadata.state_dict_metadata.keys()}",
    )

    # An untouched module must agree with the metadata derived from its twin.
    module = TestModule()
    validate_metadata(module.state_dict(), metadata)

    # Each case: (attribute to assign, factory for the new value given the
    # freshly built module, regex expected in the ValueError message).
    mismatch_cases = (
        (
            "extra_param",
            lambda mod: torch.nn.Parameter(torch.zeros(2, 2)),
            "Could not find Tensor metadata",
        ),
        (
            "regular",
            lambda mod: torch.nn.Parameter(torch.zeros(2, 4)),
            "Incompatible tensor size",
        ),
        (
            "extra_sharded",
            lambda mod: sharded_tensor.zeros(mod.spec(), 4, 2),
            "Could not find ShardedTensor metadata",
        ),
        (
            "sharded",
            lambda mod: sharded_tensor.zeros(mod.spec(), 4, 2),
            "Incompatible ShardedTensor size",
        ),
    )
    for attr_name, make_value, expected_error in mismatch_cases:
        candidate = TestModule()
        # setattr goes through nn.Module.__setattr__, exactly like `candidate.x = ...`.
        setattr(candidate, attr_name, make_value(candidate))
        with self.assertRaisesRegex(ValueError, expected_error):
            validate_metadata(candidate.state_dict(), metadata)
def test_checkpoint_has_storage_type_mismatch(self) -> None:
    """Swapping metadata entries between tensor kinds must be rejected.

    Two independent scenarios are checked, each on a fresh ``gen_metadata()``:
    the ``".sharded"`` entry is overwritten with the ``"regular"`` entry, and
    vice versa. In both cases ``validate_metadata`` against a plain
    ``TestModule`` state_dict must raise ``ValueError`` with the listed text.
    """
    module = TestModule()

    # (key whose entry is replaced, key whose entry is copied in, expected error)
    swaps = (
        (".sharded", "regular", "ShardedTensorStorageMetadata but found"),
        ("regular", ".sharded", "TensorStorageMetadata but found"),
    )
    for target_key, source_key, expected_error in swaps:
        metadata = self.gen_metadata()
        entries = metadata.state_dict_metadata
        entries[target_key] = entries[source_key]
        with self.assertRaisesRegex(ValueError, expected_error):
            validate_metadata(module.state_dict(), metadata)
def test_checkpoint_has_shard_overlap(self) -> None:
    """Enlarging the first stored shard of ``.sharded`` must be rejected.

    Every dimension of the first ``storage_metadata`` entry's
    ``shard_metadata.shard_sizes`` is increased by one, and
    ``validate_metadata`` against a plain ``TestModule`` state_dict is then
    expected to raise ``ValueError`` whose message matches ``"overlap"``.
    """
    metadata = self.gen_metadata()

    # The ShardedTensor entry in gen_metadata() is keyed ".sharded" (leading
    # dot) -- the same key test_checkpoint_has_storage_type_mismatch uses.
    self.assertIn(
        ".sharded",
        metadata.state_dict_metadata,
        f"keys: {metadata.state_dict_metadata.keys()}",
    )

    sizes = (
        metadata.state_dict_metadata[".sharded"]
        .storage_metadata[0]
        .shard_metadata.shard_sizes
    )
    # Make the first stored shard *larger* (not smaller): it should then
    # cover part of another shard's region, which the "overlap" check below
    # is expected to detect. Mutate in place so the change is visible through
    # `metadata`.
    for i, size in enumerate(sizes):
        sizes[i] = size + 1

    module = TestModule()
    with self.assertRaisesRegex(ValueError, "overlap"):
        validate_metadata(module.state_dict(), metadata)