def test_basic_save_and_load_state_dict(
    self, cpu_offload, fp16, state_dict_rank0_and_offload
):
    """
    Tests that we can save a state_dict and load it into a blank model
    with various configs such as fp16 and cpu offload and that the
    parameters match as expected.
    """
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()
        full_state_dict_mgr = self._get_full_state_dict_mgr(
            model, state_dict_rank0_and_offload
        )
        with full_state_dict_mgr:
            fsdp_state_dict = _get_state_dict(
                model, cpu_offload.offload_params, fp16
            )

        self._validate_state_dict_contents(
            fsdp_state_dict, state_dict_rank0_and_offload
        )
        if fp16:
            # Verify that fp16 is the dtype.
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure the parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify that the parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
        model_new.load_state_dict(fsdp_state_dict)

        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
        if fp16:
            for tensor in model_new.parameters():
                self.assertEqual(tensor.dtype, torch.float16)
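# A minimal sketch of what `_get_full_state_dict_mgr` above is assumed to
# return: the public `FSDP.state_dict_type` context manager configured with a
# `FullStateDictConfig`. The helper's exact behavior is an assumption; the
# FSDP APIs are real (on some PyTorch versions these names import from
# `torch.distributed.fsdp.fully_sharded_data_parallel` instead).
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def get_full_state_dict_mgr(model, rank0_and_offload: bool):
    # With `rank0_only=True`, only rank 0 materializes the full state dict;
    # with `offload_to_cpu=True`, the gathered tensors land on CPU, which
    # bounds GPU memory usage during checkpointing.
    config = FullStateDictConfig(
        rank0_only=rank0_and_offload,
        offload_to_cpu=rank0_and_offload,
    )
    return FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, config)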
def test_state_dict_rank0_offload_save_load_flow(self):
    # Test taking a checkpoint on rank 0 only and reloading it without
    # redundant CPU memory usage.
    model = TransformerWithSharedParams(
        group=dist.distributed_c10d._get_default_group()
    )
    my_auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    ctx = self._get_state_dict_mgr(model, "state_dict", True)
    with ctx:
        state_dict = deepcopy(_get_state_dict(model))

    # All ranks initialize a non-FSDP model.
    grp = dist.distributed_c10d._get_default_group()
    model_new = TransformerWithSharedParams(group=grp)
    for p in model_new.parameters():
        with torch.no_grad():
            p.zero_()

    # Only rank 0 loads the checkpoint.
    if self.rank == 0:
        model_new.load_state_dict(state_dict)

    # TransformerWithSharedParams has a buffer of zeros, so we cannot pass
    # in self.assertNotEqual since the buffers would compare equal. Instead,
    # just check that there is some difference in the model across ranks
    # before the state is broadcast.
    with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    # FSDP with sync_module_states=True broadcasts the checkpointed states.
    model_new = FSDP(
        model_new,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=my_auto_wrap_policy,
        sync_module_states=True,
    )
    # After wrapping with FSDP, the models are equal across ranks and have
    # loaded the checkpoint.
    with FSDP.summon_full_params(model_new):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    with FullyShardedDataParallel.summon_full_params(model):
        with FullyShardedDataParallel.summon_full_params(model_new):
            params = list(model.parameters())
            params_new = list(model_new.parameters())
            self.assertEqual(params, params_new)
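# A minimal sketch of what a helper like `_validate` above is assumed to do:
# all-gather each parameter across the process group and compare every
# rank's copy against rank 0's via the provided `assert_fn` (for example,
# `self.assertEqual`). The helper's real implementation may differ; only
# `dist.all_gather` and standard module iteration are relied upon here.
import torch
import torch.distributed as dist
import torch.nn as nn


def validate_across_ranks(module: nn.Module, process_group, assert_fn):
    world_size = dist.get_world_size(process_group)
    for param in module.parameters():
        # NCCL requires CUDA tensors for collectives; the test's models are
        # already on GPU at this point.
        gathered = [torch.empty_like(param) for _ in range(world_size)]
        dist.all_gather(gathered, param.detach(), group=process_group)
        for other_rank_copy in gathered[1:]:
            assert_fn(gathered[0], other_rank_copy)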
def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
    for model_call in [
        partial(self._get_simple_model),
        partial(self._get_simple_nested_model),
    ]:
        model = model_call(checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
        state_dict = _get_state_dict(model, False, False)
        # Possibly wrap the new model in an activation checkpoint wrapper
        # to test save/load with this wrapper.
        model_new = model_call(checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)
        # This would fail if checkpoint_wrapper did not correctly implement
        # the state_dict pre/post hooks.
        model_new.load_state_dict(state_dict)
        self._compare_models(model, model_new, self.assertEqual)
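# A minimal sketch of what a helper like `_zero_model` above is assumed to
# do: zero every parameter in place (gathering the full parameters first if
# the model is FSDP-wrapped) so that a later `load_state_dict` observably
# changes the model. `FSDP.summon_full_params` is the real API; the helper
# name and signature are assumptions.
import contextlib

import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def zero_model(model: nn.Module, zero_buffers: bool = False):
    ctx = (
        FSDP.summon_full_params(model)
        if isinstance(model, FSDP)
        else contextlib.nullcontext()
    )
    with ctx, torch.no_grad():
        for param in model.parameters():
            param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                buffer.zero_()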
def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
    """
    Tests that we can save a state_dict and load it into a blank model
    with various configs such as fp16 and cpu offload and that the
    parameters match as expected.
    """
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()
        fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)
        if fp16:
            # Verify that fp16 is the dtype.
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure the parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify that the parameters are the same in the new model.
        model_new.load_state_dict(fsdp_state_dict)
        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
        if fp16:
            for tensor in model_new.parameters():
                self.assertEqual(tensor.dtype, torch.float16)
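# The nested `summon_full_params` contexts above matter because, outside the
# context, each rank sees only its flattened shard of the parameters, so an
# element-wise comparison would not be meaningful. A small illustration,
# assuming an already-initialized process group and an FSDP-wrapped module
# (`show_sharding` is a hypothetical name, not part of the test file):
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def show_sharding(fsdp_module: FSDP) -> None:
    sharded_numel = sum(p.numel() for p in fsdp_module.parameters())
    with FSDP.summon_full_params(fsdp_module):
        full_numel = sum(p.numel() for p in fsdp_module.parameters())
    # With world_size > 1, the shard is roughly 1/world_size of the full
    # parameter count (plus any padding FSDP adds for even sharding).
    print(f"sharded numel: {sharded_numel}, full numel: {full_numel}")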
def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
    """Tests saving the state dict, zeroing a target model's parameters,
    and loading the state dict, where the source and target models may
    have a checkpoint wrapper."""
    for model_call in [
        partial(self._get_simple_model),
        partial(self._get_simple_nested_model),
    ]:
        model = model_call(checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
        state_dict = _get_state_dict(model, False, False)
        # Possibly wrap the new model in an activation checkpoint wrapper
        # to test save/load with this wrapper.
        model_new = model_call(checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)
        # This would fail if checkpoint_wrapper did not correctly implement
        # the state_dict pre/post hooks.
        model_new.load_state_dict(state_dict, strict=True)
        self._compare_models(model, model_new, self.assertEqual)
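# A minimal sketch of the `checkpoint_wrap` handling above, assuming the
# `checkpoint_wrapper` utility from
# `torch.distributed.algorithms._checkpoint.checkpoint_wrapper` (a private
# module whose location can vary across PyTorch versions). The wrapper adds
# a prefix to parameter names, so it must implement state_dict pre/post
# hooks that strip and restore that prefix; that is exactly what the test
# exercises when only one of the two models is wrapped.
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def build_simple_model(checkpoint_wrap: bool) -> FSDP:
    lin = nn.Linear(10, 10, bias=False).cuda()
    if checkpoint_wrap:
        # Apply activation checkpointing before wrapping with FSDP.
        lin = checkpoint_wrapper(lin)
    return FSDP(lin)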
def test_basic_save_and_load_state_dict(
    self, state_dict_type, cpu_offload, fp16, state_dict_rank0_and_offload
):
    """
    Tests that we can save a state_dict and load it into a blank model
    with various configs such as fp16 and cpu offload and that the
    parameters match as expected.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    for model_call in [
        partial(self._get_non_fsdp_root_module, cpu_offload=cpu_offload),
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()
        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with ctx:
            fsdp_state_dict = _get_state_dict(
                model, cpu_offload.offload_params, fp16
            )

        ignore_keys = [
            k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
        ]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        if fp16:
            # Verify that fp16 is the dtype.
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure the parameters are different.
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)

        # Verify that the parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            if not isinstance(model, FSDP):
                # Move everything to CPU to avoid running into
                # https://github.com/pytorch/pytorch/issues/77113; some
                # parameters will still be on GPU for non-FSDP root modules.
                for k in fsdp_state_dict.keys():
                    fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict, strict=True)
        self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)
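# A plausible definition of the `STATE_DICT_MAPPING` used above; the exact
# mapping is an assumption, but it is consistent with how the test maps the
# string-valued `state_dict_type` parameter onto FSDP's real `StateDictType`
# enum members:
from torch.distributed.fsdp import StateDictType

STATE_DICT_MAPPING = {
    "state_dict": StateDictType.FULL_STATE_DICT,
    "local_state_dict": StateDictType.LOCAL_STATE_DICT,
    "sharded_state_dict": StateDictType.SHARDED_STATE_DICT,
}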
def test_state_dict_rank0_offload_save_load_flow(self):
    """Tests saving a model checkpoint only on rank 0 and loading it only
    on rank 0 with ``sync_module_states=True`` to emulate the workflow to
    avoid redundant CPU memory usage."""
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )

    # Force the model parameters and buffers to be nonzero
    with FSDP.summon_full_params(fsdp_model):
        for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()):
            if torch.count_nonzero(tensor) == 0:
                with torch.no_grad():
                    tensor.add_(
                        torch.tensor(1, dtype=tensor.dtype, device=tensor.device)
                    )

    with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(fsdp_model))

    # Initialize a non-wrapped model on all ranks
    new_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
    )
    _zero_model(new_model, zero_buffers=True)

    # Only load the checkpoint on rank 0
    if self.rank == 0:
        new_model.load_state_dict(state_dict, strict=True)
    _assert_module_states(
        new_model,
        process_group=self.process_group,
        assert_fn=self.assertNotEqual,
    )

    # Broadcast the module states from rank 0 with `sync_module_states=True`
    new_fsdp_model = FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=auto_wrap_policy,
        sync_module_states=True,
    )
    # Check FSDP models are equal across ranks
    with FSDP.summon_full_params(new_fsdp_model):
        _assert_module_states(
            new_fsdp_model,
            process_group=self.process_group,
            assert_fn=self.assertEqual,
        )
    # Check FSDP models correctly loaded the checkpoint
    with FSDP.summon_full_params(fsdp_model):
        with FSDP.summon_full_params(new_fsdp_model):
            params = list(fsdp_model.parameters())
            params_new = list(new_fsdp_model.parameters())
            self.assertEqual(params, params_new)
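# A condensed sketch of the rank0-only checkpointing workflow exercised by
# the test above, stripped of test assertions. `model_factory` is a
# placeholder for however the application builds its module; the FSDP APIs
# (`state_dict_type`, `FullStateDictConfig`, `sync_module_states`) are real.
import torch
from torch.distributed.fsdp import FullStateDictConfig, StateDictType
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def save_and_reload_rank0(fsdp_model: FSDP, model_factory, rank: int) -> FSDP:
    # 1. Gather the full state dict on rank 0 only, offloaded to CPU, so
    #    peak memory stays bounded during checkpointing.
    config = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, config):
        state_dict = fsdp_model.state_dict()  # may be empty on nonzero ranks

    # 2. Rebuild an unwrapped model on every rank, but load the checkpoint
    #    on rank 0 only, avoiding world_size redundant CPU copies.
    new_model = model_factory()
    if rank == 0:
        new_model.load_state_dict(state_dict)

    # 3. Re-wrap with `sync_module_states=True`, which broadcasts rank 0's
    #    module states (the loaded checkpoint) to all other ranks.
    return FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
    )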