def test_validate_metadata(self) -> None:
    module = TestModule()
    metadata, _, _ = _prepare(module.state_dict())
    self.assertTrue(
        "regular" in metadata.state_dict_metadata,
        f"keys: {metadata.state_dict_metadata.keys()}",
    )

    # A fresh module with the same layout validates cleanly.
    module = TestModule()
    validate_metadata(module.state_dict(), metadata)

    # Extra regular tensor not present in the metadata.
    module = TestModule()
    module.extra_param = torch.nn.Parameter(torch.zeros(2, 2))
    with self.assertRaisesRegex(ValueError, "Could not find Tensor metadata"):
        validate_metadata(module.state_dict(), metadata)

    # Regular tensor with a mismatched size.
    module = TestModule()
    module.regular = torch.nn.Parameter(torch.zeros(2, 4))
    with self.assertRaisesRegex(ValueError, "Incompatible tensor size"):
        validate_metadata(module.state_dict(), metadata)

    # Extra ShardedTensor not present in the metadata.
    module = TestModule()
    module.extra_sharded = sharded_tensor.zeros(module.spec(), 4, 2)
    with self.assertRaisesRegex(ValueError, "Could not find ShardedTensor metadata"):
        validate_metadata(module.state_dict(), metadata)

    # ShardedTensor with a mismatched size.
    module = TestModule()
    module.sharded = sharded_tensor.zeros(module.spec(), 4, 2)
    with self.assertRaisesRegex(ValueError, "Incompatible ShardedTensor size"):
        validate_metadata(module.state_dict(), metadata)
def __init__(self) -> None:
    super().__init__()
    self.sharded: ShardedTensor = sharded_tensor.zeros(self.spec(), 4, 4)
    self.regular = torch.nn.Parameter(torch.ones(4, 4))
    self.extra_sharded: Optional[ShardedTensor] = None
    self.extra_param: Optional[torch.nn.Parameter] = None
    self._register_state_dict_hook(state_dict_hook)
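# ``self.spec()`` is used by the constructor and by test_validate_metadata but
# is not shown in this excerpt. Below is a minimal sketch of what such a helper
# could look like; the two-rank CUDA placements are an assumption, and any
# chunk spec along dim 0 would satisfy the calls above.
def spec(self) -> ChunkShardingSpec:
    return ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
        ],
    )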
def _test_common_failures(self, cmp_op):
    spec, alt_spec = self.get_gpu_specs()

    # Different values.
    st1, st2 = self.get_random_tensors(spec, spec, 10, 10)
    if self.rank == 0:
        torch.nn.init.uniform_(st1.local_shards()[0].tensor)
    self.assertFalse(cmp_op(st1, st2))

    # Different sizes.
    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.ones(spec, 10, 5)
    self.assertFalse(cmp_op(st1, st2))

    # Different sharding specs.
    st1, st2 = self.get_random_tensors(spec, alt_spec, 10, 10)
    self.assertFalse(cmp_op(st1, st2))

    # Different data.
    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.zeros(spec, 10, 10)
    self.assertFalse(cmp_op(st1, st2))

    # Different dtypes.
    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.ones(spec, 10, 10, dtype=torch.double)
    self.assertFalse(cmp_op(st1, st2))

    # Different requires_grad.
    st1 = sharded_tensor.ones(spec, 10, 10)
    st2 = sharded_tensor.ones(spec, 10, 10, requires_grad=True)
    self.assertFalse(cmp_op(st1, st2))

    # Different pin_memory (only applies to CPU tensors).
    cpu_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cpu",
            "rank:1/cpu",
            "rank:2/cpu",
            "rank:3/cpu",
        ],
    )
    st1 = sharded_tensor.ones(cpu_spec, 10, 10)
    st2 = sharded_tensor.ones(cpu_spec, 10, 10, pin_memory=True)
    self.assertFalse(cmp_op(st1, st2))

    # Comparing tensors built on different process groups should raise.
    pg = dist.new_group([1, 0, 3, 2])
    st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
    with self.assertRaisesRegex(
        RuntimeError, "All distributed tensors should use the same ProcessGroup"
    ):
        cmp_op(st1, st2)

    pg = dist.new_group([0, 1, 2, 3])
    st1, st2 = self.get_random_tensors(spec, spec, 10, 10, pg2=pg)
    with self.assertRaisesRegex(
        RuntimeError, "All distributed tensors should use the same ProcessGroup"
    ):
        cmp_op(st1, st2)
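# Neither ``get_gpu_specs`` nor ``get_random_tensors`` appears in this excerpt.
# The sketches below show roughly what _test_common_failures assumes; the
# placements, the seed value, and the signatures are assumptions rather than
# the real helpers.
def get_gpu_specs(self):
    # ``spec`` and ``alt_spec`` cover the same ranks but place shards on
    # different devices, so tensors built from them compare as unequal.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:0",
            "rank:1/cuda:1",
            "rank:2/cuda:2",
            "rank:3/cuda:3",
        ],
    )
    alt_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0/cuda:1",
            "rank:1/cuda:0",
            "rank:2/cuda:3",
            "rank:3/cuda:2",
        ],
    )
    return spec, alt_spec

def get_random_tensors(self, spec1, spec2, *sizes, pg1=None, pg2=None):
    # Seed identically before each creation so both tensors start from the
    # same random values; callers then perturb one side as needed.
    torch.manual_seed(0)
    st1 = sharded_tensor.rand(spec1, *sizes, process_group=pg1)
    torch.manual_seed(0)
    st2 = sharded_tensor.rand(spec2, *sizes, process_group=pg2)
    return st1, st2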
def test_tensor_metadata_with_missing_rank_spec(self) -> None:
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:1/cuda:1",
        ],
    )

    st = sharded_tensor.zeros(spec, 4, 4, dtype=torch.float64)
    mapping = dict()

    (_, md, storage_md) = _prepare_sharded_tensor_write("fqn", st, "tensor", mapping)

    self.assertEqual(1, len(storage_md))
    self.assertEqual(1, len(mapping))
def test_switch_between_sharded_tensor_to_tensor(self) -> None:
    path = self.get_file_path()
    tensor_size = 32

    specs = [
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        ),
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
                "rank:1",
                "rank:0",
            ],
        ),
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[8],
                placement="rank:1",
            ),
            ShardMetadata(
                shard_offsets=[8],
                shard_sizes=[tensor_size - 8],
                placement="rank:0",
            ),
        ]),
        EnumerableShardingSpec(shards=[
            ShardMetadata(
                shard_offsets=[0],
                shard_sizes=[10],
                placement="rank:0",
            ),
            ShardMetadata(
                shard_offsets=[10],
                shard_sizes=[tensor_size - 10],
                placement="rank:1",
            ),
        ]),
    ]

    for save_spec in specs:
        for load_spec in specs:
            save_dict = {
                'sharded': sharded_tensor.rand(save_spec, tensor_size),
                'replicated': torch.rand(tensor_size, device=f"cpu:{self.rank}"),
            }

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

            # Freaky Friday the tensors
            load_dict = {
                'sharded': torch.zeros(tensor_size, device=f"cpu:{self.rank}"),
                'replicated': sharded_tensor.zeros(load_spec, tensor_size),
            }

            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

            save_dict_sharded = self.load_tensor(save_dict['sharded'])
            load_dict_replicated = self.load_tensor(load_dict['replicated'])

            if dist.get_rank() == 0:
                self.assertTrue(
                    torch.allclose(save_dict_sharded, load_dict['sharded']),
                    f"save-spec {save_spec} load-spec {load_spec}",
                )
                self.assertTrue(
                    torch.allclose(save_dict['replicated'], load_dict_replicated),
                    f"save-spec {save_spec} load-spec {load_spec}",
                )
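# ``self.load_tensor`` is defined elsewhere in the test class. A minimal sketch
# (an assumption, not the canonical implementation) gathers a ShardedTensor
# into a full CPU tensor on rank 0 so it can be compared with torch.allclose
# in the test above; non-zero ranks receive None.
def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
    res = torch.zeros(tensor.shape, device="cpu") if dist.get_rank() == 0 else None
    tensor.gather(out=res)
    return res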