def test_validate_metadata(self) -> None:
    """validate_metadata accepts a matching state_dict and rejects mismatches.

    Covers both rejection messages for regular tensors (unknown key,
    wrong size) and for sharded tensors (unknown key, wrong size).
    """
    md, _, _ = _prepare(TestModule().state_dict())
    self.assertTrue(
        "regular" in md.state_dict_metadata,
        f"keys: {md.state_dict_metadata.keys()}",
    )

    # A fresh, unmodified module matches the prepared metadata.
    validate_metadata(TestModule().state_dict(), md)

    # Extra regular tensor key -> no metadata entry for it.
    mod = TestModule()
    mod.extra_param = torch.nn.Parameter(torch.zeros(2, 2))
    with self.assertRaisesRegex(ValueError, "Could not find Tensor metadata"):
        validate_metadata(mod.state_dict(), md)

    # Known regular tensor key, but a different size.
    mod = TestModule()
    mod.regular = torch.nn.Parameter(torch.zeros(2, 4))
    with self.assertRaisesRegex(ValueError, "Incompatible tensor size"):
        validate_metadata(mod.state_dict(), md)

    # Extra sharded tensor key -> no metadata entry for it.
    mod = TestModule()
    mod.extra_sharded = sharded_tensor.zeros(mod.spec(), 4, 2)
    with self.assertRaisesRegex(ValueError, "Could not find ShardedTensor metadata"):
        validate_metadata(mod.state_dict(), md)

    # Known sharded tensor key, but a different size.
    mod = TestModule()
    mod.sharded = sharded_tensor.zeros(mod.spec(), 4, 2)
    with self.assertRaisesRegex(ValueError, "Incompatible ShardedTensor size"):
        validate_metadata(mod.state_dict(), md)
def _load():
    # Closure: state_dict, write_replicated, kwargs, coordinator and
    # no_dist are captured from the enclosing test scope.
    md, _, _ = _prepare(state_dict, write_replicated)
    reader = FaultyStorageReader(md, kwargs)
    load_state_dict(
        state_dict,
        storage_reader=reader,
        coordinator_rank=coordinator,
        no_dist=no_dist,
    )
def gen_metadata(self) -> Metadata:
    """Prepare save metadata for a TestModule and share it across all ranks."""
    # The second positional argument asks _prepare to also cover replicated
    # (non-sharded) tensors; without it the metadata would be incomplete.
    # NOTE(review): other call sites name this parameter write_replicated —
    # confirm against _prepare's signature.
    md, _, _ = _prepare(TestModule().state_dict(), True)
    # broadcast_object_list mutates the list in place, so after this call
    # every rank holds the coordinator's metadata object.
    object_list = [md]
    dist.broadcast_object_list(object_list)
    return object_list[0]
def test_storage_key_mapping(self) -> None:
    """Every write request's storage_key must be recorded in the prepared metadata."""
    device = f"cuda:{dist.get_rank()}"
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

    state_dict = {
        'sharded': sharded_tensor.rand(spec, (10, 10)),
        'replicated': torch.rand(4, device=device),
        'bytes': [1, 2, 3, 4],
    }

    # Only the coordinator (rank 0) writes replicated data and bytes.
    metadata, bytes_reqs, tensor_reqs = _prepare(
        state_dict, write_replicated_data=self.rank == 0
    )

    if self.rank == 0:
        self.assertEqual(1, len(bytes_reqs))
        self.assertEqual(2, len(tensor_reqs))

        self.assertIn('bytes', metadata.state_dict_metadata)
        self.assertEqual(
            bytes_reqs[0].storage_key,
            metadata.state_dict_metadata['bytes'].storage_key,
        )

        # Request ordering is unspecified; the replicated tensor is the 1-D one.
        first, second = tensor_reqs
        if len(first.tensor.size()) == 1:
            replicated, shard = first, second
        else:
            replicated, shard = second, first

        self.assertIn('replicated', metadata.state_dict_metadata)
        self.assertEqual(
            replicated.storage_key,
            metadata.state_dict_metadata['replicated'].storage_key,
        )
    else:
        # Non-coordinator ranks only write their local shard.
        self.assertEqual(0, len(bytes_reqs))
        self.assertEqual(1, len(tensor_reqs))
        shard = tensor_reqs[0]

    # Both ranks hold a local shard; its key must appear among the
    # per-shard storage keys recorded for 'sharded'.
    self.assertIn('sharded', metadata.state_dict_metadata)
    shard_keys = [
        sm.storage_key
        for sm in metadata.state_dict_metadata['sharded'].storage_metadata
    ]
    self.assertIn(shard.storage_key, shard_keys)
def test_storage_key_mapping(self) -> None:
    """Every write request must have a matching entry in metadata.storage_data.

    Fix: two assertions used ``self.assertTrue(a, b)``, where ``b`` is only
    the failure *message* — the intended equality between the recorded
    storage key and the request's storage key was never checked. They are
    now ``self.assertEqual(a, b)``.
    """
    device = f"cuda:{dist.get_rank()}"
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )

    state_dict = {
        'sharded': sharded_tensor.rand(spec, (10, 10)),
        'replicated': torch.rand(4, device=device),
        'bytes': [1, 2, 3, 4],
    }

    # Only the coordinator (rank 0) writes replicated data and bytes.
    metadata, bytes_reqs, tensor_reqs = _prepare(
        state_dict, write_replicated_data=self.rank == 0
    )

    if self.rank == 0:
        self.assertEqual(1, len(bytes_reqs))
        self.assertEqual(2, len(tensor_reqs))

        self.assertTrue('bytes' in metadata.state_dict_metadata)
        self.assertTrue(MetadataIndex('bytes') in metadata.storage_data)

        # tensor ordering is unspecified; the replicated tensor is 1-D.
        if len(tensor_reqs[0].tensor.size()) == 1:
            replicated = tensor_reqs[0]
            shard = tensor_reqs[1]
        else:
            replicated = tensor_reqs[1]
            shard = tensor_reqs[0]

        self.assertTrue('replicated' in metadata.state_dict_metadata)
        storage_key = MetadataIndex('replicated', torch.Size([0]))
        self.assertTrue(storage_key in metadata.storage_data)
        # Was assertTrue(a, b) — that never compared the values.
        self.assertEqual(metadata.storage_data[storage_key], replicated.storage_key)
    else:
        # Non-coordinator ranks only write their local shard.
        self.assertEqual(0, len(bytes_reqs))
        self.assertEqual(1, len(tensor_reqs))
        shard = tensor_reqs[0]

    # Both ranks hold a local shard; look it up by its offsets.
    local_shard = state_dict["sharded"].local_shards()[0]
    self.assertTrue('sharded' in metadata.state_dict_metadata)
    storage_key = MetadataIndex(
        'sharded', torch.Size(local_shard.metadata.shard_offsets)
    )
    self.assertTrue(storage_key in metadata.storage_data)
    # Was assertTrue(a, b) — that never compared the values.
    self.assertEqual(metadata.storage_data[storage_key], shard.storage_key)