def test_basic_save_and_load_state_dict(self, cpu_offload, fp16, state_dict_rank0_and_offload):
    """
    Checks that a saved state_dict can be loaded into a zeroed model.

    Runs once for the simple nested model and once for the simple model,
    both built with ``cpu_offload``. The state_dict is taken inside the
    manager returned by ``_get_full_state_dict_mgr``; a second model is
    built and zeroed, and the full parameters of the two models must differ
    before ``load_state_dict`` and match after it. With ``fp16`` set, the
    saved tensors and the reloaded parameters must be ``torch.float16``.
    """
    summon = FullyShardedDataParallel.summon_full_params
    model_builders = (
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    )
    for build_model in model_builders:
        model = build_model()
        with self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload):
            fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)
        self._validate_state_dict_contents(fsdp_state_dict, state_dict_rank0_and_offload)

        if fp16:
            # The saved tensors themselves must already be half precision.
            for saved in fsdp_state_dict.values():
                self.assertEqual(saved.dtype, torch.float16)

        model_new = build_model()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the fresh model so it cannot match the trained one by accident.
        _zero_model(model_new)
        with summon(model), summon(model_new):
            self.assertNotEqual(
                list(model.parameters()), list(model_new.parameters())
            )

        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key, tensor in fsdp_state_dict.items():
                fsdp_state_dict[key] = tensor.cuda()

        # After loading, the parameters of both models must agree.
        model_new.load_state_dict(fsdp_state_dict)
        with summon(model_new), summon(model):
            self.assertEqual(
                list(model.parameters()), list(model_new.parameters())
            )
            if fp16:
                for param in model_new.parameters():
                    self.assertEqual(param.dtype, torch.float16)
def _compare_models(self, model, model_new, assert_fn, check_fp16=False):
    """
    Compares the full parameters of ``model`` and ``model_new``.

    Full parameters are summoned for both models, then ``assert_fn`` is
    called with the two parameter lists (``model`` first). When
    ``check_fp16`` is true, every parameter of ``model_new`` must also have
    dtype ``torch.float16``.
    """
    summon = FullyShardedDataParallel.summon_full_params
    with summon(model), summon(model_new):
        assert_fn(list(model.parameters()), list(model_new.parameters()))
        if not check_fp16:
            return
        for param in model_new.parameters():
            self.assertEqual(param.dtype, torch.float16)
def test_state_dict_load_into_local_module(self):
    """
    Tests that FSDP's state_dict can be loaded into a local model.

    An FSDP-wrapped model is trained for three SGD steps, its full
    parameters are snapshotted, and its ``state_dict()`` is loaded into a
    zeroed, unwrapped model whose parameters must then equal the snapshot.
    """
    model = self._initialize_model(wrap_fsdp=True)
    optim = SGD(model.parameters(), lr=0.1)
    in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
    for _ in range(3):
        model(in_data).sum().backward()
        optim.step()
        optim.zero_grad()

    # Snapshot the trained parameters while they are unsharded.
    with FullyShardedDataParallel.summon_full_params(model):
        expected_params = deepcopy(list(model.parameters()))

    # state_dict() is called with no state-dict-type override here, so
    # whatever FSDP returns by default is what gets loaded below.
    fsdp_state_dict = model.state_dict()

    # Build an unwrapped model and wipe its parameters.
    blank_local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
    with torch.no_grad():
        for param in blank_local_model.parameters():
            param.zero_()

    # Load FSDP's state dict into the local model and compare pairwise.
    blank_local_model.load_state_dict(fsdp_state_dict)
    loaded_params = list(blank_local_model.parameters())
    for expected, loaded in zip(expected_params, loaded_params):
        self.assertEqual(expected, loaded)
def test_state_dict_rank0_offload_save_load_flow(self):
    """
    Tests taking a checkpoint on rank 0 only and reloading it without
    redundant CPU memory.

    The checkpoint is loaded into an unwrapped model on rank 0 only, and
    wrapping that model with ``sync_module_states=True`` is then expected
    to make all ranks equal to each other and to the original FSDP model.
    """
    summon = FullyShardedDataParallel.summon_full_params
    model = TransformerWithSharedParams(
        group=dist.distributed_c10d._get_default_group()
    )
    my_auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    with self._get_state_dict_mgr(model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(model))

    # All ranks initialize a zeroed non-FSDP model.
    grp = dist.distributed_c10d._get_default_group()
    model_new = TransformerWithSharedParams(group=grp)
    with torch.no_grad():
        for param in model_new.parameters():
            param.zero_()

    # Only rank 0 loads the checkpoint.
    if self.rank == 0:
        model_new.load_state_dict(state_dict)

    # TransformerWithSharedParams has a buffer of zeros, so can't pass in
    # self.assertNotEqual since the buffers would be equal. So just checking
    # that there is some difference in the model across ranks before
    # state_dict is broadcasted.
    with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    # FSDP with sync_module_states=True broadcasts the checkpointed states.
    model_new = FSDP(
        model_new,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=my_auto_wrap_policy,
        sync_module_states=True,
    )

    # After wrapping with FSDP, models are equal across ranks...
    with FSDP.summon_full_params(model_new):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    # ...and carry the checkpointed values of the original model.
    with summon(model), summon(model_new):
        self.assertEqual(list(model.parameters()), list(model_new.parameters()))
def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
    """
    Checks that a saved state_dict can be loaded into a zeroed model.

    Runs once for the simple nested model and once for the simple model,
    both built with ``cpu_offload``. A second model is built and zeroed,
    and the full parameters of the two models must differ before
    ``load_state_dict`` and match after it. With ``fp16`` set, the saved
    tensors and the reloaded parameters must be ``torch.float16``.
    """
    summon = FullyShardedDataParallel.summon_full_params
    model_builders = (
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    )
    for build_model in model_builders:
        model = build_model()
        fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)
        if fp16:
            # The saved tensors themselves must already be half precision.
            for saved in fsdp_state_dict.values():
                self.assertEqual(saved.dtype, torch.float16)

        model_new = build_model()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the fresh model so it cannot match the trained one by accident.
        _zero_model(model_new)
        with summon(model), summon(model_new):
            self.assertNotEqual(
                list(model.parameters()), list(model_new.parameters())
            )

        # After loading, the parameters of both models must agree.
        model_new.load_state_dict(fsdp_state_dict)
        with summon(model_new), summon(model):
            self.assertEqual(
                list(model.parameters()), list(model_new.parameters())
            )
            if fp16:
                for param in model_new.parameters():
                    self.assertEqual(param.dtype, torch.float16)
def test_state_dict_load_into_local_module(
    self, state_dict_type, state_dict_rank0_and_offload
):
    """
    Tests that FSDP's state_dict can be loaded into a local model.

    An FSDP-wrapped model (with registered buffers) is trained for three
    SGD steps and its full parameters are snapshotted. Its state_dict is
    taken under the manager for ``state_dict_type``, passed through
    ``_gather_state_dict`` (and broadcast and moved to CUDA when
    ``state_dict_rank0_and_offload`` is set), then loaded into a zeroed,
    unwrapped model on every rank. The loaded parameters must equal the
    snapshot.

    The rank0-and-offload mode is only exercised together with the
    ``"state_dict"`` type; other combinations return immediately.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return

    model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
    optim = SGD(model.parameters(), lr=0.1)
    in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
    for _ in range(3):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    # Snapshot the trained parameters while they are unsharded.
    with FullyShardedDataParallel.summon_full_params(model):
        fsdp_params = deepcopy(list(model.parameters()))

    # Get the FSDP state_dict under the requested state-dict configuration.
    sd_mgr = self._get_state_dict_mgr(
        model, state_dict_type, state_dict_rank0_and_offload
    )
    with sd_mgr:
        fsdp_state_dict = model.state_dict()

    self._validate_state_dict_contents(
        fsdp_state_dict, state_dict_rank0_and_offload
    )

    # Create zeroed local model.
    blank_local_model = self._initialize_model(
        wrap_fsdp=False,
        wrap_ddp=False,
        register_buffers=True,
    )
    with torch.no_grad():
        for param in blank_local_model.parameters():
            param.zero_()

    fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

    if state_dict_rank0_and_offload:
        # Broadcast + CUDA state_dict so ranks other than 0 also hold it.
        fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
        for key in fsdp_state_dict.keys():
            fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

    # Load on every rank, not just rank 0: the state dict has been gathered
    # (and broadcast if it was offloaded) above, and the comparison below
    # runs against this rank's local model. A rank that skipped the load
    # would still hold zeroed parameters.
    blank_local_model.load_state_dict(fsdp_state_dict)
    local_params = list(blank_local_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
def test_distributed_checkpoint(self, state_dict_type) -> None:
    """
    Round-trips an FSDP model through a file-system checkpoint.

    Two wrapped ``SkipModel`` instances are built under different seeds and
    must start with different full parameters. The first model's state_dict
    (of ``state_dict_type``) is saved to a temporary directory whose path
    is shared through ``dist.broadcast_object_list``. The second model's
    state_dict is then filled from that checkpoint and loaded back, after
    which the full parameters must match.
    """
    summon = FullyShardedDataParallel.summon_full_params
    with enable_wrap(wrapper_cls=FSDP):
        torch.manual_seed(100)
        model = wrap(SkipModel(double_nest=True))
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))

    # Different seeds: the two models must not agree yet.
    with summon(model), summon(new_model):
        self.assertNotEqual(
            list(model.parameters()), list(new_model.parameters())
        )

    with tempfile.TemporaryDirectory() as path:
        # Agree on a single checkpoint directory across ranks.
        shared_paths = [path]
        dist.broadcast_object_list(shared_paths)
        path = shared_paths[0]

        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)

        with FSDP.state_dict_type(model, state_dict_type), \
                FSDP.state_dict_type(new_model, state_dict_type):
            state_dict = model.state_dict()

        save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(model, state_dict_type), \
                FSDP.state_dict_type(new_model, state_dict_type):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

    # After the round trip both models must hold the same parameters.
    with summon(model), summon(new_model):
        self.assertEqual(
            list(model.parameters()), list(new_model.parameters())
        )
def test_summon_from_non_fsdp(self):
    """
    Checks ``summon_full_params`` when called on a non-FSDP root module.

    A plain ``nn.Module`` container holding two FSDP-wrapped submodules and
    one unwrapped submodule is compared against a container holding three
    unwrapped ``DeterministicModel`` instances. Clones of the parameters
    seen under ``summon_full_params`` must equal the unwrapped parameters.
    """

    class FSDPContainer(nn.Module):
        # Simple holder exposing three child modules as attributes.
        def __init__(self, fsdp_1, fsdp_2, fsdp_3):
            super().__init__()
            self.fsdp_1 = fsdp_1
            self.fsdp_2 = fsdp_2
            self.fsdp_3 = fsdp_3

    # Children are built in the same left-to-right order as the arguments.
    wrapped_first = FSDP(DeterministicModel(wrap_fsdp=True))
    wrapped_second = FSDP(DeterministicModel(wrap_fsdp=True))
    plain_third = DeterministicModel(wrap_fsdp=False)
    model_fsdp = FSDPContainer(wrapped_first, wrapped_second, plain_third)

    reference_children = [DeterministicModel(wrap_fsdp=False) for _ in range(3)]
    model_no_fsdp = FSDPContainer(*reference_children)

    params_to_compare = list(model_no_fsdp.parameters())
    with FullyShardedDataParallel.summon_full_params(model_fsdp):
        fsdp_params = [param.clone() for param in model_fsdp.parameters()]
        self.assertEqual(params_to_compare, fsdp_params)
def test_state_dict_load_into_local_module(
    self,
    state_dict_type,
    state_dict_rank0_and_offload,
    fsdp_root,
):
    """
    Tests that FSDP's state_dict can be loaded into a local model.

    Depending on ``fsdp_root``, the model under test is either the
    FSDP-wrapped model from ``_initialize_model`` or the module from
    ``_get_non_fsdp_root_module``. After three SGD steps its full
    parameters are snapshotted and its state_dict is taken under the
    manager for ``state_dict_type``. The state_dict is gathered (and
    broadcast and moved to CUDA when ``state_dict_rank0_and_offload`` is
    set), then strictly loaded into a zeroed model that contains no FSDP
    modules. The loaded parameters must equal the snapshot.

    The rank0-and-offload mode is only exercised together with the
    ``"state_dict"`` type; other combinations return immediately.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return

    if fsdp_root:
        model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
    else:
        model = self._get_non_fsdp_root_module()
    optim = SGD(model.parameters(), lr=0.1)
    if fsdp_root:
        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
    else:
        in_data = torch.randn(1, 10, requires_grad=True, device=torch.device("cuda"))
    for _ in range(3):
        model(in_data).sum().backward()
        optim.step()
        optim.zero_grad()

    # Snapshot the trained parameters while they are unsharded.
    with FullyShardedDataParallel.summon_full_params(model):
        fsdp_params = deepcopy(list(model.parameters()))

    # Get the FSDP state_dict under the requested state-dict configuration.
    with self._get_state_dict_mgr(model, state_dict_type, state_dict_rank0_and_offload):
        fsdp_state_dict = model.state_dict()

    # Keys containing NON_ROOT_FSDP_PREFIX are passed to validation as keys
    # to ignore.
    ignore_keys = [
        key for key in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in key
    ]
    self._validate_state_dict_contents(
        model,
        fsdp_state_dict,
        state_dict_rank0_and_offload,
        ignore_keys=ignore_keys,
    )

    # Create zeroed local model.
    if fsdp_root:
        blank_local_model = self._initialize_model(
            wrap_fsdp=False, wrap_ddp=False, register_buffers=True
        )
    else:
        blank_local_model = self._get_non_fsdp_root_module(wrap=False)

    # Nothing should be FSDP.
    for mod in blank_local_model.modules():
        self.assertFalse(isinstance(mod, FSDP))

    with torch.no_grad():
        for param in blank_local_model.parameters():
            param.zero_()

    fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

    if state_dict_rank0_and_offload:
        # Broadcast + CUDA state_dict.
        if not isinstance(model, FSDP):
            # Some portions of the model on rank 0 might not be on CPU,
            # move everything to CPU to avoid running into
            # https://github.com/pytorch/pytorch/issues/77113.
            cpu_device = torch.device("cpu")
            for key, tensor in fsdp_state_dict.items():
                if tensor.device != cpu_device:
                    fsdp_state_dict[key] = tensor.cpu()

        fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
        for key in fsdp_state_dict.keys():
            fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

    # The load is unconditional, so every rank loads FSDP's state dict into
    # its local model and verifies the parameters are as expected.
    blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
    local_params = list(blank_local_model.parameters())
    for expected, loaded in zip(fsdp_params, local_params):
        self.assertEqual(expected, loaded)
def test_state_dict_rank0_offload_save_load_flow(self):
    """Tests saving a checkpoint on rank 0 only and loading it on rank 0
    only, relying on ``sync_module_states=True`` to distribute it.

    This emulates the workflow that avoids redundant CPU memory usage. The
    checkpoint is loaded into an unwrapped, zeroed model on rank 0, ranks
    are expected to differ at that point, and wrapping with FSDP must make
    them equal to each other and to the original FSDP model."""
    summon = FullyShardedDataParallel.summon_full_params
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )

    # Force model parameters and buffers to be nonzero so that the zeroed
    # model below differs from the checkpoint in every tensor.
    with FSDP.summon_full_params(fsdp_model):
        all_tensors = itertools.chain(
            fsdp_model.parameters(), fsdp_model.buffers()
        )
        for tensor in all_tensors:
            if torch.count_nonzero(tensor) != 0:
                continue
            with torch.no_grad():
                one = torch.tensor(1, dtype=tensor.dtype, device=tensor.device)
                tensor.add_(one)

    with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(fsdp_model))

    # Initialize a non-wrapped, zeroed model on all ranks.
    new_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
    )
    _zero_model(new_model, zero_buffers=True)

    # Only load the checkpoint on rank 0.
    if self.rank == 0:
        new_model.load_state_dict(state_dict, strict=True)
    _assert_module_states(
        new_model,
        process_group=self.process_group,
        assert_fn=self.assertNotEqual,
    )

    # Broadcast the module states from rank 0 with `sync_module_states=True`.
    new_fsdp_model = FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=auto_wrap_policy,
        sync_module_states=True,
    )

    # Check FSDP models are equal across ranks.
    with FSDP.summon_full_params(new_fsdp_model):
        _assert_module_states(
            new_fsdp_model,
            process_group=self.process_group,
            assert_fn=self.assertEqual,
        )

    # Check FSDP models correctly loaded the checkpoint.
    with summon(fsdp_model), summon(new_fsdp_model):
        self.assertEqual(
            list(fsdp_model.parameters()), list(new_fsdp_model.parameters())
        )