def test_save_and_load_after_forward_state_dict(self):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    torch.cuda.set_device(self.rank)
    model = self._get_wrapped_model(
        group=torch.distributed.distributed_c10d._get_default_group())
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = _get_full_detached_param(model)
    for _ in range(6):
        inp = model.module.get_input(torch.device("cuda"))
        output = model(*inp)
        loss = model.module.get_loss(inp, output).cuda()
        model.module.run_backward(loss)
        optim.step()

    trained_params = _get_full_detached_param(model)
    # Ensure some training occurred
    self.assertNotEqual(initial_params, trained_params)
    # Save a copy of the state_dict
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    _zero_model(model)

    # Load state_dict into zeroed model
    model.load_state_dict(state_dict)
    loaded_params = _get_full_detached_param(model)
    self.assertEqual(loaded_params, trained_params)
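# The helpers `_get_full_detached_param` and `_zero_model` are referenced throughout
# these tests but are not defined in this excerpt. A minimal sketch of what they might
# look like, assuming `summon_full_params` is used to materialize the unsharded
# parameters; the real helpers in the FSDP test utilities may differ:

def _get_full_detached_param(model):
    # Gather the full (unsharded) parameters and return detached clones so that
    # later training steps do not mutate the snapshot.
    with FullyShardedDataParallel.summon_full_params(model):
        return [p.detach().clone() for p in model.parameters()]

def _zero_model(model, zero_buffers=False):
    # Zero parameters (and optionally buffers) in place so that a subsequent
    # load_state_dict() must restore them for the test to pass.
    with FullyShardedDataParallel.summon_full_params(model):
        with torch.no_grad():
            for p in model.parameters():
                p.zero_()
            if zero_buffers:
                for b in model.buffers():
                    b.zero_()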
def test_save_and_load_after_forward_state_dict(self, mixed_precision):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    torch.cuda.set_device(self.rank)
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = self._get_simple_nested_model(mixed_precision=mixed_precision)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = _get_full_detached_param(model)
    for _ in range(6):
        inp = torch.randn(1, 10, device=torch.cuda.current_device())
        output = model(*inp)
        loss = output.sum()
        expected_dtype = torch.float32 if mixed_precision is None else torch.float16
        self.assertEqual(expected_dtype, loss.dtype)
        loss.backward()
        optim.step()

    trained_params = _get_full_detached_param(model)
    # Ensure some training occurred
    self.assertNotEqual(initial_params, trained_params)
    # Save a copy of the state_dict
    state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    _zero_model(model)

    # Ensure checkpointed params have the full param dtype
    for tensor in state_dict.values():
        self.assertEqual(tensor.dtype, torch.float32)

    # Load state_dict into zeroed model
    model.load_state_dict(state_dict)
    loaded_params = _get_full_detached_param(model)
    self.assertEqual(loaded_params, trained_params)
def test_basic_save_and_load_state_dict(self, cpu_offload, fp16,
                                        state_dict_rank0_and_offload):
    """
    Tests that we can save a state_dict and load it into a blank model
    with various configs such as fp16 and CPU offload, and that the
    parameters match as expected.
    """
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()

        full_state_dict_mgr = self._get_full_state_dict_mgr(
            model, state_dict_rank0_and_offload)
        with full_state_dict_mgr:
            fsdp_state_dict = _get_state_dict(
                model, cpu_offload.offload_params, fp16)

        self._validate_state_dict_contents(
            fsdp_state_dict, state_dict_rank0_and_offload)
        if fp16:
            # Verify fp16 is the type
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        model_new.load_state_dict(fsdp_state_dict)
        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
                if fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)
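# `self._get_full_state_dict_mgr` is not shown in this excerpt. A plausible sketch,
# assuming it wraps FSDP.state_dict_type with a FullStateDictConfig whose
# rank0_only/offload_to_cpu flags follow `state_dict_rank0_and_offload`; the helper
# name and exact behavior here are assumptions, not the actual test code:

def _get_full_state_dict_mgr(self, model, state_dict_rank0_and_offload):
    from torch.distributed.fsdp import FullStateDictConfig, StateDictType
    config = FullStateDictConfig(
        rank0_only=state_dict_rank0_and_offload,
        offload_to_cpu=state_dict_rank0_and_offload,
    )
    # Returns a context manager; inside it, model.state_dict() produces a full
    # (unsharded) state dict according to the config.
    return FullyShardedDataParallel.state_dict_type(
        model, StateDictType.FULL_STATE_DICT, config)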
def test_save_and_load_after_forward_state_dict(
    self, mixed_precision, state_dict_rank0_and_offload
):
    """
    Test that saving after some training results in params being updated as
    expected.
    """
    torch.cuda.set_device(self.rank)
    mixed_precision = (
        MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        if mixed_precision
        else None
    )
    model = self._get_simple_nested_model(mixed_precision=mixed_precision)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    initial_params = _get_full_detached_param(model)
    for _ in range(6):
        inp = torch.randn(1, 10, device=torch.cuda.current_device())
        output = model(*inp)
        loss = output.sum()
        expected_dtype = torch.float32 if mixed_precision is None else torch.float16
        self.assertEqual(expected_dtype, loss.dtype)
        loss.backward()
        optim.step()

    trained_params = _get_full_detached_param(model)
    # Ensure some training occurred
    self.assertNotEqual(initial_params, trained_params)

    # Save a copy of the state_dict
    fsd_mgr = self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload)
    with fsd_mgr:
        state_dict = {k: v.clone() for k, v in model.state_dict().items()}
    self._validate_state_dict_contents(state_dict, state_dict_rank0_and_offload)
    _zero_model(model)

    # Ensure checkpointed params have the full param dtype
    for tensor in state_dict.values():
        self.assertEqual(tensor.dtype, torch.float32)

    # Load state_dict into zeroed model
    if state_dict_rank0_and_offload:
        # Broadcast the state dict and move it back to GPU in
        # preparation for loading.
        state_dict = self._broadcast_state_dict(state_dict)
        for key in state_dict.keys():
            state_dict[key] = state_dict[key].cuda()
    model.load_state_dict(state_dict)
    loaded_params = _get_full_detached_param(model)
    self.assertEqual(loaded_params, trained_params)
def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
    for model_call in [
        partial(self._get_simple_model),
        partial(self._get_simple_nested_model),
    ]:
        model = model_call(
            checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
        state_dict = _get_state_dict(model, False, False)
        # Possibly wrap the new model in an activation checkpoint wrapper to
        # test saving/loading with this wrapper.
        model_new = model_call(
            checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)
        # Would fail if checkpoint_wrapper did not correctly implement
        # state_dict pre/post hooks.
        model_new.load_state_dict(state_dict)
        self._compare_models(model, model_new, self.assertEqual)
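# `checkpoint_wrap` toggles whether the source and/or target model is built with an
# activation checkpoint wrapper. A rough sketch of how `_get_simple_model` might honor
# that flag; the model structure and helper signature are assumptions, not the actual
# test code:

def _get_simple_model(self, checkpoint_wrap=False, *fsdp_args, **fsdp_kwargs):
    from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
        checkpoint_wrapper,
    )
    lin = torch.nn.Linear(10, 10, bias=False).cuda()
    if checkpoint_wrap:
        # checkpoint_wrapper installs state_dict pre/post hooks that strip its
        # prefix, which is exactly what the test above exercises.
        lin = checkpoint_wrapper(lin)
    return FSDP(lin, *fsdp_args, **fsdp_kwargs)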
def test_basic_save_and_load_state_dict(self, cpu_offload, fp16):
    """
    Tests that we can save a state_dict and load it into a blank model
    with various configs such as fp16 and CPU offload, and that the
    parameters match as expected.
    """
    for model_call in [
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()
        fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16)

        if fp16:
            # Verify fp16 is the type
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure parameters are different.
        _zero_model(model_new)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertNotEqual(params, params_new)

        # Verify parameters are the same in the new model.
        model_new.load_state_dict(fsdp_state_dict)
        with FullyShardedDataParallel.summon_full_params(model_new):
            with FullyShardedDataParallel.summon_full_params(model):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
                if fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)
def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
    """Tests saving the state dict, zeroing a target model's parameters, and
    loading the state dict, where the source and target models may have a
    checkpoint wrapper."""
    for model_call in [
        partial(self._get_simple_model),
        partial(self._get_simple_nested_model),
    ]:
        model = model_call(
            checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
        state_dict = _get_state_dict(model, False, False)
        # Possibly wrap the new model in an activation checkpoint wrapper to
        # test saving/loading with this wrapper.
        model_new = model_call(
            checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)
        # Would fail if checkpoint_wrapper did not correctly implement
        # state_dict pre/post hooks.
        model_new.load_state_dict(state_dict, strict=True)
        self._compare_models(model, model_new, self.assertEqual)
def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""): # TODO: Move this test to common_fsdp. model = self._initialize_model(wrap_fsdp) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if wrap_fsdp: blank_model = FSDP(Model(True).cuda()) _zero_model(blank_model) state_dict = self._state_dict(model, state_dict_type) self._load_state_dict(blank_model, state_dict_type, state_dict) return get_full_params(blank_model) else: return list(model.parameters())
def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = "", with_context: bool = False): # TODO: Move this test to common_fsdp. model = self._initialize_model(wrap_fsdp) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if wrap_fsdp: blank_model = FSDP(Model(True).cuda()) _zero_model(blank_model) if with_context: state_dict_type = { "state_dict": StateDictType.FULL_STATE_DICT, "local_state_dict": StateDictType.LOCAL_STATE_DICT, "sharded_state_dict": StateDictType.SHARDED_STATE_DICT, }[state_dict_type] with model.state_dict_type(state_dict_type): state_dict = model.state_dict() with blank_model.state_dict_type(state_dict_type): blank_model.load_state_dict(state_dict) else: state_dict = self._state_dict(model, state_dict_type) self._load_state_dict(blank_model, state_dict_type, state_dict) return get_full_params(blank_model) else: return list(model.parameters())
def test_state_dict_skip_module(self, state_dict_type, double_nest):
    torch.cuda.set_device(self.rank)

    def _create_module(wrap_fsdp=True):
        LINEAR_SKIP = "linear_skip"
        ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
        with ctx:
            module = SkipModel(double_nest=double_nest)
            # Full name of linear_skip param tensors in SkipModel, as would be
            # stored in checkpoint.
            linear_skip_tensor_names = [
                k for k in dict(module.named_parameters()).keys()
                if LINEAR_SKIP in k
            ]
            # Detach SkipModel's linear_skip module so FSDP does not wrap it.
            linear_skip = getattr(module, LINEAR_SKIP)
            delattr(module, LINEAR_SKIP)
            # Wrap FSDP
            fsdp = wrap(module)
            # Reattach
            setattr(module, LINEAR_SKIP, linear_skip)
            return fsdp, linear_skip_tensor_names

    fsdp, linear_skip_tensor_names = _create_module()
    # Run a forward pass
    inp = torch.randn((1, 10), device=torch.cuda.current_device())
    loss = fsdp(inp)
    loss.sum().backward()

    with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
        state_dict = fsdp.state_dict()
    if self.rank == 0 and state_dict_type != "local_state_dict":
        sd_keys = list(state_dict.keys())
        expected = list(SkipModel(double_nest=False).state_dict().keys())
        self.assertEqual(sorted(sd_keys), sorted(expected))
        # TODO: parameters in linear_skip_tensor_names should not be handled
        # by FSDP.state_dict(). Have a check once this is implemented in
        # FSDP.state_dict().

    # Check that it can be loaded into FSDP.
    new_fsdp, _ = _create_module()
    _zero_model(new_fsdp)
    for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
        self.assertNotEqual(p1, p2)
    with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
        if state_dict_type != "local_state_dict":
            # FlatParameter does not support deepcopy yet.
            state_dict = deepcopy(state_dict)
        new_fsdp.load_state_dict(state_dict, strict=True)
    for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
        self.assertEqual(p1, p2)

    # Test that the checkpoint can be loaded into a local model.
    local, _ = _create_module(wrap_fsdp=False)
    for param in local.parameters():
        with torch.no_grad():
            param.zero_()

    with fsdp.summon_full_params(fsdp):
        for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
            self.assertNotEqual(p1, p2)

    if state_dict_type == "local_state_dict":
        return
    state_dict = _gather_state_dict(state_dict)
    with fsdp.summon_full_params(fsdp):
        if self.rank == 0:
            local.load_state_dict(state_dict, strict=True)
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertEqual(p1, p2)
def test_basic_save_and_load_state_dict(self, state_dict_type, cpu_offload,
                                        fp16, state_dict_rank0_and_offload):
    """
    Tests that we can save a state_dict and load it into a blank model
    with various configs such as fp16 and CPU offload, and that the
    parameters match as expected.
    """
    if state_dict_rank0_and_offload and state_dict_type != "state_dict":
        return
    for model_call in [
        partial(self._get_non_fsdp_root_module, cpu_offload=cpu_offload),
        partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
        partial(self._get_simple_model, cpu_offload=cpu_offload),
    ]:
        model = model_call()

        ctx = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload)
        with ctx:
            fsdp_state_dict = _get_state_dict(
                model, cpu_offload.offload_params, fp16)

        ignore_keys = [
            k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
        ]

        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        if fp16:
            # Verify fp16 is the type
            for tensor in fsdp_state_dict.values():
                self.assertEqual(tensor.dtype, torch.float16)

        model_new = model_call()
        if not cpu_offload.offload_params:
            model_new = model_new.cuda()
        if fp16:
            model_new.half()

        # Zero the model to ensure parameters are different.
        _zero_model(model_new)
        self._compare_models(model, model_new, self.assertNotEqual)

        # Verify parameters are the same in the new model.
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            if not isinstance(model, FSDP):
                # Move everything to CPU to avoid running into
                # https://github.com/pytorch/pytorch/issues/77113; some params
                # will still be on GPU for non-FSDP root modules.
                for k in fsdp_state_dict.keys():
                    fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
            model_new.load_state_dict(fsdp_state_dict, strict=True)

        self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)
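# `self._broadcast_state_dict` ships the rank 0 state dict to the other ranks when the
# checkpoint was saved with rank0_only=True. A minimal sketch, assuming the tensors are
# already on CPU and using torch.distributed.broadcast_object_list; the real helper may
# broadcast tensor by tensor instead:

def _broadcast_state_dict(self, state_dict):
    import torch.distributed as dist
    # Only rank 0 holds the full state dict; other ranks receive it here.
    olist = [state_dict if self.rank == 0 else None]
    dist.broadcast_object_list(olist, src=0)
    return olist[0]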
def test_state_dict_rank0_offload_save_load_flow(self):
    """Tests saving a model checkpoint only on rank 0 and loading it only
    on rank 0 with ``sync_module_states=True`` to emulate the workflow used
    to avoid redundant CPU memory usage."""
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )
    # Force model parameters and buffers to be nonzero
    with FSDP.summon_full_params(fsdp_model):
        for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()):
            if torch.count_nonzero(tensor) == 0:
                with torch.no_grad():
                    tensor.add_(
                        torch.tensor(1, dtype=tensor.dtype, device=tensor.device))
    with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(fsdp_model))

    # Initialize a non-wrapped model on all ranks
    new_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
    )
    _zero_model(new_model, zero_buffers=True)
    # Only load the checkpoint on rank 0
    if self.rank == 0:
        new_model.load_state_dict(state_dict, strict=True)
    _assert_module_states(
        new_model,
        process_group=self.process_group,
        assert_fn=self.assertNotEqual,
    )
    # Broadcast the module states from rank 0 with `sync_module_states=True`
    new_fsdp_model = FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=auto_wrap_policy,
        sync_module_states=True,
    )
    # Check FSDP models are equal across ranks
    with FSDP.summon_full_params(new_fsdp_model):
        _assert_module_states(
            new_fsdp_model,
            process_group=self.process_group,
            assert_fn=self.assertEqual,
        )
    # Check FSDP models correctly loaded the checkpoint
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
            params = list(fsdp_model.parameters())
            params_new = list(new_fsdp_model.parameters())
            self.assertEqual(params, params_new)
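# `_assert_module_states` verifies that parameters and buffers agree (or, before the
# sync, deliberately disagree) across ranks. A simplified sketch using
# all_gather_object; the real helper lives in the FSDP test utilities and may differ:

def _assert_module_states(model, process_group, assert_fn):
    import torch.distributed as dist
    # Snapshot this rank's parameters and buffers on CPU.
    tensors = [
        t.detach().cpu()
        for t in itertools.chain(model.parameters(), model.buffers())
    ]
    world_size = dist.get_world_size(process_group)
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, tensors, group=process_group)
    # Compare every non-zero rank's states against rank 0's states.
    rank0_states = gathered[0]
    for other in gathered[1:]:
        for t0, t1 in zip(rank0_states, other):
            assert_fn(t0, t1)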