def _load():
    metadata, _, _ = _prepare(state_dict, write_replicated)
    load_state_dict(
        state_dict,
        storage_reader=FaultyStorageReader(metadata, kwargs),
        coordinator_rank=coordinator,
        no_dist=no_dist,
    )
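# A sketch of how the `_load` closure above is plausibly driven; the
# surrounding harness is not shown here, so this shape is an assumption.
# FaultyStorageReader injects the failures described by `kwargs`, so a
# fault-injecting configuration should make the load raise, while an
# empty `kwargs` should load cleanly.
if kwargs:
    # Faults were requested: the failure must surface to the caller.
    with self.assertRaises(Exception):
        _load()
else:
    # No faults: the round-trip should succeed.
    _load()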
def test_load_rowwise_to_colwise(self) -> None:
    path = self.get_file_path()
    self.assertEqual(self.world_size, dist.get_world_size())

    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    src_spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0",
            "rank:1",
        ],
    )

    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    dst_spec = ChunkShardingSpec(
        dim=1,
        placements=[
            "rank:0",
            "rank:1",
        ],
    )

    if dist.get_rank() == 0:
        shutil.rmtree(path, ignore_errors=True)
        os.makedirs(path)

    model_to_save = MyShardedModel3(src_spec).cuda(dist.get_rank())
    model_to_save._register_state_dict_hook(state_dict_hook)
    state_dict_to_save = model_to_save.state_dict()

    fs_writer = FileSystemWriter(path=path)
    save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

    model_to_load = MyShardedModel3(dst_spec).cuda(dist.get_rank())
    model_to_load._register_state_dict_hook(state_dict_hook)
    state_dict_to_load_to = model_to_load.state_dict()

    fs_reader = FileSystemReader(path=path)
    load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

    # The two ShardedTensors can't be compared directly with torch.allclose
    # since each has a different sharding spec; gather both to rank 0 first.
    store_tensor = self.load_tensor(model_to_save.sharded_tensor)
    load_tensor = self.load_tensor(model_to_load.sharded_tensor)

    if dist.get_rank() == 0:
        self.assertTrue(torch.allclose(store_tensor, load_tensor))
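# The helpers used throughout these tests are defined elsewhere in the
# file; the sketches below are assumptions about their shape, not the
# verbatim definitions. `state_dict_hook` exposes the module's
# ShardedTensor in the state dict, and `load_tensor` gathers a
# ShardedTensor into a single full tensor on rank 0 so that two tensors
# with different sharding specs can be compared with torch.allclose.
def state_dict_hook(module, destination, prefix, local_metadata):
    # Hypothetical: publish the sharded tensor under a stable key.
    destination[prefix + "sharded_tensor"] = module.sharded_tensor

def load_tensor(self, st):
    # Gather every shard to rank 0; other ranks get None back.
    out = torch.empty(st.size(), device="cpu") if dist.get_rank() == 0 else None
    st.gather(dst=0, out=out)
    return out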
def test_read_write_shard_tensor(self) -> None:
    paths = [tempfile.mkdtemp()]
    dist.broadcast_object_list(paths)
    path = paths[0]

    # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
    spec = ChunkShardingSpec(
        dim=0,
        placements=[
            "rank:0",
            "rank:1",
        ],
    )

    model_to_save = MyShardedModel1(spec, init_rrefs=False)

    # Test save.
    model_to_save._register_state_dict_hook(state_dict_hook)
    state_dict_to_save = model_to_save.state_dict()

    fs_writer = FileSystemWriter(path=path)
    save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

    dist.barrier()

    # Create a new model.
    model_to_load = MyShardedModel1(spec, init_rrefs=False)
    # This is not the correct hook for loading the state dict:
    # model_to_load._register_load_state_dict_pre_hook(pre_load_state_dict_hook, True)
    model_to_load._register_state_dict_hook(state_dict_hook)
    state_dict_to_load_to = model_to_load.state_dict()

    dist.barrier()

    with self.assertRaises(AssertionError):
        assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

    # Test load.
    fs_reader = FileSystemReader(path=path)
    load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

    assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
    dist.barrier()
def test_save_load_bytes(self) -> None:
    path = self.get_file_path()

    state_dict_to_save = {'bytes0': [1], 'bytes1': 'string'}

    fs_writer = FileSystemWriter(path=path)
    save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

    state_dict_to_load = {'bytes0': [2], 'bytes1': 'other'}

    fs_reader = FileSystemReader(path=path)
    load_state_dict(state_dict=state_dict_to_load, storage_reader=fs_reader)

    self.assertEqual([1], state_dict_to_load['bytes0'])
    self.assertEqual('string', state_dict_to_load['bytes1'])
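# Note: the non-tensor entries ('bytes0', 'bytes1') round-trip through the
# checkpoint as opaque serialized blobs, which is why the loaded dict's
# placeholder values ([2], 'other') are fully overwritten by the saved
# values rather than merged.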
def test_read_write_only_tensor(self) -> None:
    with tempfile.TemporaryDirectory() as path:
        state_dict_to_save = MyTestModule().state_dict()

        fs_writer = FileSystemWriter(path=path)
        save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer, no_dist=True)

        state_dict_to_load_to = MyTestModule().state_dict()

        with self.assertRaises(AssertionError):
            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)

        # Load from file without any resharding.
        fs_reader = FileSystemReader(path=path)
        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader, no_dist=True)

        assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
def test_distributed_checkpoint(self, state_dict_type) -> None:
    with enable_wrap(wrapper_cls=FSDP):
        torch.manual_seed(100)
        model = wrap(SkipModel(double_nest=True))
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))

    with FullyShardedDataParallel.summon_full_params(
        model
    ), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertNotEqual(params, new_params)

    with tempfile.TemporaryDirectory() as path:
        paths = [path]
        dist.broadcast_object_list(paths)
        path = paths[0]
        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = model.state_dict()

        save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

    with FullyShardedDataParallel.summon_full_params(
        model
    ), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertEqual(params, new_params)
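# `state_dict_type` is supplied by the test runner; a plausible
# parametrization (an assumption based on the FSDP API, not shown in this
# file) would be:
#
#   @parametrize(
#       "state_dict_type",
#       [StateDictType.LOCAL_STATE_DICT, StateDictType.SHARDED_STATE_DICT],
#   )
#
# where StateDictType comes from torch.distributed.fsdp.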
def test_switch_between_sharded_tensor_to_tensor(self) -> None:
    path = self.get_file_path()
    tensor_size = 32

    specs = [
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        ),
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
                "rank:1",
                "rank:0",
            ],
        ),
        EnumerableShardingSpec(
            shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[8],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8],
                    shard_sizes=[tensor_size - 8],
                    placement="rank:0",
                ),
            ]
        ),
        EnumerableShardingSpec(
            shards=[
                ShardMetadata(
                    shard_offsets=[0],
                    shard_sizes=[10],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[10],
                    shard_sizes=[tensor_size - 10],
                    placement="rank:1",
                ),
            ]
        ),
    ]

    for save_spec in specs:
        for load_spec in specs:
            save_dict = {
                'sharded': sharded_tensor.rand(save_spec, tensor_size),
                'replicated': torch.rand(tensor_size, device=f"cpu:{self.rank}"),
            }

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=save_dict, storage_writer=fs_writer)

            # "Freaky Friday" the tensors: load the sharded entry into a
            # regular tensor and the replicated entry into a sharded one.
            load_dict = {
                'sharded': torch.zeros(tensor_size, device=f"cpu:{self.rank}"),
                'replicated': sharded_tensor.zeros(load_spec, tensor_size),
            }

            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=load_dict, storage_reader=fs_reader)

            save_dict_sharded = self.load_tensor(save_dict['sharded'])
            load_dict_replicated = self.load_tensor(load_dict['replicated'])

            if dist.get_rank() == 0:
                self.assertTrue(
                    torch.allclose(save_dict_sharded, load_dict['sharded']),
                    f"save-spec {save_spec} load-spec {load_spec}",
                )
                self.assertTrue(
                    torch.allclose(save_dict['replicated'], load_dict_replicated),
                    f"save-spec {save_spec} load-spec {load_spec}",
                )
def test_load_with_different_shard_plan(self) -> None:
    path = self.get_file_path()

    # We hardcode the assumption of how many shards are around.
    self.assertEqual(self.world_size, dist.get_world_size())

    specs = [
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
            ],
        ),
        # pyre-fixme [28]: Unexpected keyword argument `dim` to call `dist._sharding_spec.api.ChunkShardingSpec.__init__`.
        ChunkShardingSpec(
            dim=0,
            placements=[
                "rank:0",
                "rank:1",
                "rank:1",
                "rank:0",
            ],
        ),
        # This requires the tensors to be [10, 20].
        EnumerableShardingSpec(
            shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[2, 0],
                    shard_sizes=[1, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[3, 0],
                    shard_sizes=[3, 20],
                    placement="rank:0",
                ),
                ShardMetadata(
                    shard_offsets=[6, 0],
                    shard_sizes=[3, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[9, 0],
                    shard_sizes=[1, 20],
                    placement="rank:0",
                ),
            ]
        ),
        # This requires the tensors to be [10, 20].
        EnumerableShardingSpec(
            shards=[
                ShardMetadata(
                    shard_offsets=[0, 0],
                    shard_sizes=[8, 20],
                    placement="rank:1",
                ),
                ShardMetadata(
                    shard_offsets=[8, 0],
                    shard_sizes=[2, 20],
                    placement="rank:0",
                ),
            ]
        ),
    ]

    for s0 in specs:
        for s1 in specs:
            if s0 == s1:
                continue

            if dist.get_rank() == 0:
                shutil.rmtree(path, ignore_errors=True)
                os.makedirs(path)
            dist.barrier()

            model_to_save = MyShardedModel3(s0)
            model_to_save._register_state_dict_hook(state_dict_hook)
            state_dict_to_save = model_to_save.state_dict()

            fs_writer = FileSystemWriter(path=path)
            save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)

            dist.barrier()

            model_to_load = MyShardedModel3(s1)
            model_to_load._register_state_dict_hook(state_dict_hook)
            state_dict_to_load_to = model_to_load.state_dict()
            dist.barrier()

            fs_reader = FileSystemReader(path=path)
            load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)

            dist.barrier()
            store_tensor = self.load_tensor(model_to_save.sharded_tensor)
            dist.barrier()
            load_tensor = self.load_tensor(model_to_load.sharded_tensor)

            if dist.get_rank() == 0:
                self.assertTrue(
                    torch.allclose(store_tensor, load_tensor),
                    msg=f"{s0} vs {s1}",
                )