def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)
    dev = (
        torch.device("cpu") if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )
    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        # FSDP param deepcopy() of params has issues
        fsdp_params = [p.clone() for p in model.parameters()]
        self.assertEqual(fsdp_params, params_to_compare)
def test_summon_full_param_writeback(self, writeback, modify_outer):
    return _run_test_summon_full_param_writeback(
        self,
        writeback,
        cpu_offload=CPUOffload(offload_params=False),
        modify_outer=modify_outer,
    )
def _dist_train(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
    # keep everything deterministic for input data
    torch.manual_seed(0)

    model = Model(wrap_fsdp, cpu_offload)
    if wrap_fsdp:
        model = FSDP(model, cpu_offload=cpu_offload)
    else:
        model = DistributedDataParallel(model, device_ids=[self.rank])
    model.half()

    optim = SGD(model.parameters(), lr=0.1)
    in_data = torch.rand(16, 2).cuda().half()
    in_data.requires_grad = True
    for _ in range(1):
        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()

    if wrap_fsdp:
        get_full_params(model)

    return list(model.parameters())
def __init__(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
    super().__init__()
    # keep everything deterministic for model initialization
    torch.manual_seed(0)
    self.inner = torch.nn.Linear(2, 2).cuda()
    if wrap_fsdp:
        self.inner = FullyShardedDataParallel(self.inner, cpu_offload=cpu_offload)
    self.outer = torch.nn.Linear(2, 2).cuda()
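# A minimal sketch (an editorial aside, not part of the test file) of how the
# wrapping above surfaces in parameter names: with wrap_fsdp=True, `inner`
# becomes its own FSDP unit while `outer` is flattened into the root FSDP's
# flat_param, which is what the "flat_param" / "inner.flat_param"
# local_state_dict keys checked elsewhere reflect. Assumes an initialized
# process group and one GPU; the exact prefixes vary across FSDP versions.
def _show_flat_param_names():
    model = FullyShardedDataParallel(Model(wrap_fsdp=True))
    # e.g. ['_fsdp_wrapped_module.flat_param',
    #       '_fsdp_wrapped_module._fpw_module.inner._fsdp_wrapped_module.flat_param']
    print([name for name, _ in model.named_parameters()])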
def test_mixed_precision_e2e_full_shard(self):
    mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
    self._run_test_mixed_precision_e2e(
        mp_config=mp,
        cpu_offload=CPUOffload(offload_params=True),
        backward_prefetch=None,
        full_precision_param_dtype=torch.float64,
        sharding_strategy=ShardingStrategy.FULL_SHARD,
    )
def test_summon_full_param_writeback(self, writeback, modify_outer, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    return _run_test_summon_full_param_writeback(
        self,
        writeback,
        modify_outer=modify_outer,
        cpu_offload=CPUOffload(offload_params=False),
        mixed_precision=mixed_precision,
    )
class TestPureFP16(FSDPTest):
    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload: CPUOffload):
        """Tests pure FP16 training, including when the parameter's dtype is
        changed after FSDP initialization and before training."""
        self._test_fsdp_parity(
            NestedWrappedModule,
            FSDPInitMode.RECURSIVE,
            cuda_init_mode=CUDAInitMode.CUDA_AFTER,
            # Run one iteration to avoid NaN without a gradient scaler
            num_iters=1,
            cpu_offload=cpu_offload,
            use_pure_fp16=True,
        )
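# An editorial sketch of the pure-FP16 pattern the test above exercises: wrap
# with FSDP first, then change the parameter dtype with .half() before
# training. Assumes a single GPU and spins up a one-rank process group purely
# for illustration; the module and hyperparameters are made up.
import os

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

def _pure_fp16_sketch():
    if not dist.is_initialized():
        os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
        os.environ.setdefault("MASTER_PORT", "29500")
        dist.init_process_group("nccl", rank=0, world_size=1)
    torch.cuda.set_device(0)
    model = FSDP(torch.nn.Linear(8, 8).cuda()).half()  # dtype changed after FSDP init
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    inp = torch.rand(4, 8, device="cuda", dtype=torch.float16)
    model(inp).sum().backward()  # one step avoids FP16 overflow without a scaler
    optim.step()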
class TestPureFP16(FSDPTest):
    def _dist_train(self, wrap_fsdp, cpu_offload=CPUOffload(offload_params=False)):
        # keep everything deterministic for input data
        torch.manual_seed(0)

        model = Model(wrap_fsdp, cpu_offload)
        if wrap_fsdp:
            model = FSDP(model, cpu_offload=cpu_offload)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        model.half()

        optim = SGD(model.parameters(), lr=0.1)
        in_data = torch.rand(16, 2).cuda().half()
        in_data.requires_grad = True
        for _ in range(1):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            get_full_params(model)

        return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_pure_fp16(self, cpu_offload):
        # DDP
        ddp_state = self._dist_train(wrap_fsdp=False)

        # FSDP
        fsdp_state = self._dist_train(wrap_fsdp=True, cpu_offload=cpu_offload)

        self.assertEqual(ddp_state, fsdp_state)
def test_mixed_precision_no_reshard_after_forward(self):
    # Note that we don't exercise all possible different configs so as to
    # not increase test TTS too much.
    mp = default_mp if not nccl_supports_bf16 else mp_diff_buffer_and_reduce
    self._run_test_mixed_precision_e2e(
        mp_config=mp,
        cpu_offload=CPUOffload(offload_params=True),
        backward_prefetch=None,
        full_precision_param_dtype=torch.float64,
        sharding_strategy=ShardingStrategy.SHARD_GRAD_OP,
    )
class TestSummonFullParamsNoShard(FSDPTest):
    @property
    def world_size(self):
        return 1  # does not shard

    @skip_if_lt_x_gpu(2)
    @parametrize("writeback", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            modify_outer=modify_outer,
            cpu_offload=cpu_offload,
        )
def test_summon_full_params_equivalence(self):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    with model.summon_full_params(recurse=True):
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        fsdp_params = deepcopy(list(model.parameters()))

    self.assertEqual(fsdp_params, list(local_model.parameters()))
def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )

    writeback = not rank0_only
    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=writeback,
        offload_to_cpu=offload_to_cpu,
    ):
        if writeback:
            with torch.no_grad():
                for p in model.parameters():
                    p.add_(1)
                for p in params_to_compare:
                    p.add_(1)
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        # FSDP param deepcopy() of params has issues
        fsdp_params = [p.clone() for p in model.parameters()]
        self.assertEqual(fsdp_params, params_to_compare)

    # CPU offload is enabled for main API, so we should point back to CPU
    for param in model.parameters():
        self.assertEqual(param.device, torch.device("cpu"))
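# An editorial sketch of the writeback semantics checked above, using the
# imports already present in this test file: edits made to the materialized
# full parameters inside summon_full_params(writeback=True) persist in the
# local shards after the context exits. Assumes an initialized process group
# and a GPU; shapes are made up.
def _writeback_sketch():
    fsdp_model = FSDP(nn.Linear(4, 4, bias=False).cuda())
    with FSDP.summon_full_params(fsdp_model, writeback=True):
        with torch.no_grad():
            for p in fsdp_model.parameters():
                p.zero_()  # mutate the unsharded view
    # After exiting, the zeros were written back into the flat parameter shard.
    assert all(torch.all(p == 0) for p in fsdp_model.parameters())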
class TestFSDPStateDict(FSDPTest):
    @property
    def world_size(self):
        return 2

    def _broadcast_state_dict(self, state_dict):
        olist = [state_dict if self.rank == 0 else None]
        dist.broadcast_object_list(olist)
        return olist[0]

    def _compare_models(self, model, model_new, assert_fn, check_fp16=False):
        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                assert_fn(params, params_new)
                if check_fp16:
                    for tensor in model_new.parameters():
                        self.assertEqual(tensor.dtype, torch.float16)

    def _get_simple_nested_model(
        self, *fsdp_args, wrap=True, checkpoint_wrap=False, **fsdp_kwargs
    ):
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.Linear(10, 10, bias=False).cuda(),
            )
        return model

    def _get_simple_model(self, *fsdp_args, checkpoint_wrap=False, **fsdp_kwargs):
        lin = nn.Linear(10, 10, bias=False).cuda()
        if checkpoint_wrap:
            lin = checkpoint_wrapper(lin)
        model = FSDP(lin, *fsdp_args, **fsdp_kwargs)
        return model

    def _get_non_fsdp_root_module(self, *fsdp_args, wrap=True, **fsdp_kwargs):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2):
                super().__init__()
                self.non_fsdp_lin = nn.Linear(10, 10, bias=False).cuda()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2

            def forward(self, x):
                x = self.non_fsdp_lin(x)
                x = self.fsdp_1(x)
                x = self.fsdp_2(x)
                return x

        return FSDPContainer(
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
            self._get_simple_nested_model(*fsdp_args, wrap=wrap, **fsdp_kwargs),
        )

    def _get_state_dict_mgr(
        self,
        model: nn.Module,
        state_dict_type: str,
        state_dict_rank0_and_offload: bool,
    ):
        _state_dict_type = STATE_DICT_MAPPING[state_dict_type]
        if state_dict_type == "state_dict":
            config = FullStateDictConfig(
                rank0_only=state_dict_rank0_and_offload,
                offload_to_cpu=state_dict_rank0_and_offload,
            )
        else:
            config = None
        return FSDP.state_dict_type(model, _state_dict_type, config)

    def _validate_state_dict_contents(
        self, model, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=None
    ):
        if state_dict_rank0_and_offload:
            if self.rank == 0:
                self.assertNotEqual(fsdp_state_dict, {})
                for key, tensor in fsdp_state_dict.items():
                    if ignore_keys and key in ignore_keys:
                        continue
                    self.assertEqual(
                        tensor.device,
                        torch.device("cpu"),
                        f"{key} is unexpectedly on device {tensor.device}",
                    )
            else:
                # For non-FSDP roots, the non-FSDP portion can still have
                # parameters in the state_dict on nonzero ranks, so bypass the
                # check for now.
                if isinstance(model, FSDP):
                    self.assertEqual(fsdp_state_dict, {})

    @skip_if_lt_x_gpu(2)
    @parametrize("checkpoint_wrap", ["first", "second", "both"])
    def test_fsdp_state_dict_with_activation_checkpoint(self, checkpoint_wrap):
        """Tests saving the state dict, zeroing a target model's parameters,
        and loading the state dict, where the source and target models may
        have a checkpoint wrapper."""
        for model_call in [
            partial(self._get_simple_model),
            partial(self._get_simple_nested_model),
        ]:
            model = model_call(checkpoint_wrap=(checkpoint_wrap in ["first", "both"]))
            state_dict = _get_state_dict(model, False, False)
            # Possibly wrap new model in activation checkpoint wrapper to test
            # save/load with this wrapper
            model_new = model_call(checkpoint_wrap=(checkpoint_wrap in ["second", "both"]))
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)
            # Would fail if checkpoint_wrapper did not correctly implement
            # state_dict pre/post hooks
            model_new.load_state_dict(state_dict, strict=True)
            self._compare_models(model, model_new, self.assertEqual)

    @skip_if_lt_x_gpu(2)
    def test_state_dict_rank0_offload_save_load_flow(self):
        """Tests saving a model checkpoint only on rank 0 and loading it only
        on rank 0 with ``sync_module_states=True`` to emulate the workflow to
        avoid redundant CPU memory usage."""
        auto_wrap_policy = partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
        )
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        # Force model parameters and buffers to be nonzero
        with FSDP.summon_full_params(fsdp_model):
            for tensor in itertools.chain(fsdp_model.parameters(), fsdp_model.buffers()):
                if torch.count_nonzero(tensor) == 0:
                    with torch.no_grad():
                        tensor.add_(torch.tensor(1, dtype=tensor.dtype, device=tensor.device))
        with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
            state_dict = deepcopy(_get_state_dict(fsdp_model))
        # Initialize a non-wrapped model on all ranks
        new_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
        )
        _zero_model(new_model, zero_buffers=True)
        # Only load the checkpoint on rank 0
        if self.rank == 0:
            new_model.load_state_dict(state_dict, strict=True)
        _assert_module_states(
            new_model,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Broadcast the module states from rank 0 with `sync_module_states=True`
        new_fsdp_model = FSDP(
            new_model,
            device_id=torch.cuda.current_device(),
            auto_wrap_policy=auto_wrap_policy,
            sync_module_states=True,
        )
        # Check FSDP models are equal across ranks
        with FSDP.summon_full_params(new_fsdp_model):
            _assert_module_states(
                new_fsdp_model,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )
        # Check FSDP models correctly loaded the checkpoint
        with FullyShardedDataParallel.summon_full_params(fsdp_model):
            with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
                params = list(fsdp_model.parameters())
                params_new = list(new_fsdp_model.parameters())
                self.assertEqual(params, params_new)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("fp16", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_basic_save_and_load_state_dict(
        self, state_dict_type, cpu_offload, fp16, state_dict_rank0_and_offload
    ):
        """
        Tests that we can save a state_dict and load it into a blank model
        with various configs such as fp16 and cpu offload and parameters
        match as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        for model_call in [
            partial(self._get_non_fsdp_root_module, cpu_offload=cpu_offload),
            partial(self._get_simple_nested_model, cpu_offload=cpu_offload),
            partial(self._get_simple_model, cpu_offload=cpu_offload),
        ]:
            model = model_call()
            ctx = self._get_state_dict_mgr(
                model, state_dict_type, state_dict_rank0_and_offload
            )
            with ctx:
                fsdp_state_dict = _get_state_dict(
                    model, cpu_offload.offload_params, fp16
                )

            ignore_keys = [
                k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k
            ]
            self._validate_state_dict_contents(
                model,
                fsdp_state_dict,
                state_dict_rank0_and_offload,
                ignore_keys=ignore_keys,
            )
            if fp16:
                # Verify fp16 is the type
                for tensor in fsdp_state_dict.values():
                    self.assertEqual(tensor.dtype, torch.float16)

            model_new = model_call()
            if not cpu_offload.offload_params:
                model_new = model_new.cuda()
            if fp16:
                model_new.half()

            # zero the model to ensure parameters are different.
            _zero_model(model_new)
            self._compare_models(model, model_new, self.assertNotEqual)

            # Verify parameters are the same in the new model.
            if state_dict_rank0_and_offload:
                # Broadcast the state dict and move it back to GPU in
                # preparation for loading.
                if not isinstance(model, FSDP):
                    # Move everything to CPU to avoid running into
                    # https://github.com/pytorch/pytorch/issues/77113; some
                    # params will still be on GPU for non-FSDP root modules.
                    for k in fsdp_state_dict.keys():
                        fsdp_state_dict[k] = fsdp_state_dict[k].cpu()
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
            with FSDP.state_dict_type(model_new, STATE_DICT_MAPPING[state_dict_type]):
                model_new.load_state_dict(fsdp_state_dict, strict=True)
            self._compare_models(model, model_new, self.assertEqual, check_fp16=fp16)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
        self, state_dict_type, mixed_precision, state_dict_rank0_and_offload
    ):
        """
        Test that saving after some training results in params being updated
        as expected.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        torch.cuda.set_device(self.rank)
        mixed_precision = (
            MixedPrecision(
                param_dtype=torch.float16,
                reduce_dtype=torch.float16,
                buffer_dtype=torch.float16,
            )
            if mixed_precision
            else None
        )
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = get_full_params(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(*inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()

        trained_params = get_full_params(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with fsd_mgr:
            state_dict = model.state_dict()
            if state_dict_type == "state_dict":
                state_dict = {k: v.clone() for k, v in state_dict.items()}
            else:
                for sharded_tensor in state_dict.values():
                    shard = sharded_tensor._local_shards[0]
                    shard.tensor = shard.tensor.clone().detach_()
        self._validate_state_dict_contents(
            model, state_dict, state_dict_rank0_and_offload
        )
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()
        with FSDP.state_dict_type(model, STATE_DICT_MAPPING[state_dict_type]):
            model.load_state_dict(state_dict, strict=True)
        loaded_params = get_full_params(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(
        self,
        wrap_fsdp: bool,
        wrap_ddp: bool = True,
        register_buffers: bool = False,
    ):
        # keep everything deterministic for input data
        torch.manual_seed(0)
        model = Model(wrap_fsdp, register_buffers=register_buffers).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")
        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()

    @staticmethod
    def _load_state_dict(
        model: Module, state_dict_type: str, state_dict: Dict[str, Any]
    ):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict for {state_dict_type}")
        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict, strict=True)

    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
        model = self._initialize_model(wrap_fsdp)
        optim = SGD(model.parameters(), lr=0.1)

        in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        if wrap_fsdp:
            blank_model = FSDP(Model(True).cuda())
            _zero_model(blank_model)
            state_dict = self._state_dict(model, state_dict_type)
            self._load_state_dict(blank_model, state_dict_type, state_dict)
            return get_full_params(blank_model)
        else:
            return list(model.parameters())

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_state_dict_save_load_flow(self, state_dict_type):
        fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type)
        ddp_params = self._dist_train(wrap_fsdp=False)
        self.assertEqual(ddp_params, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    def test_fsdp_state_dict_keys(self, state_dict_type):
        state_dict = self._state_dict(self._initialize_model(True), state_dict_type)
        if state_dict_type == "local_state_dict":
            self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys())
        elif state_dict_type in ("state_dict", "sharded_state_dict"):
            # Keys should match local model.
            local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False)
            local_keys = local_model.state_dict().keys()
            self.assertEqual(state_dict.keys(), local_keys)
        else:
            raise NotImplementedError(f"No test for {state_dict_type}!")

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _UNFLATTENED_STATE_DICT_IMPLS)
    @parametrize("state_dict_rank0_and_offload", [True, False])
    @parametrize("fsdp_root", [True, False])
    def test_state_dict_load_into_local_module(
        self,
        state_dict_type,
        state_dict_rank0_and_offload,
        fsdp_root,
    ):
        """
        Tests that FSDP's state_dict can be loaded into a local model.
        """
        if state_dict_rank0_and_offload and state_dict_type != "state_dict":
            return
        if not fsdp_root:
            model = self._get_non_fsdp_root_module()
        else:
            model = self._initialize_model(wrap_fsdp=True, register_buffers=True)
        optim = SGD(model.parameters(), lr=0.1)
        if not fsdp_root:
            in_data = torch.randn(1, 10, requires_grad=True, device=torch.device("cuda"))
        else:
            in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda"))
        for _ in range(3):
            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()

        with FullyShardedDataParallel.summon_full_params(model):
            fsdp_params = deepcopy(list(model.parameters()))

        # get FSDP state_dict. Note that by default we return full_state_dict.
        sd_mgr = self._get_state_dict_mgr(
            model, state_dict_type, state_dict_rank0_and_offload
        )
        with sd_mgr:
            fsdp_state_dict = model.state_dict()

        ignore_keys = [k for k in fsdp_state_dict.keys() if NON_ROOT_FSDP_PREFIX in k]
        self._validate_state_dict_contents(
            model,
            fsdp_state_dict,
            state_dict_rank0_and_offload,
            ignore_keys=ignore_keys,
        )
        # Create zeroed local model
        if not fsdp_root:
            blank_local_model = self._get_non_fsdp_root_module(wrap=False)
        else:
            blank_local_model = self._initialize_model(
                wrap_fsdp=False, wrap_ddp=False, register_buffers=True
            )

        # Nothing should be FSDP
        for mod in blank_local_model.modules():
            self.assertFalse(isinstance(mod, FSDP))

        for param in blank_local_model.parameters():
            with torch.no_grad():
                param.zero_()

        fsdp_state_dict = _gather_state_dict(fsdp_state_dict)

        # Load fsdp's full state dict into the local and verify params are as
        # expected.
        if state_dict_rank0_and_offload:
            # Broadcast + CUDA state_dict
            if not isinstance(model, FSDP):
                # Some portions of the model on rank 0 might not be on CPU,
                # move everything to CPU to avoid running into
                # https://github.com/pytorch/pytorch/issues/77113.
                for k, t in fsdp_state_dict.items():
                    if t.device != torch.device("cpu"):
                        fsdp_state_dict[k] = t.cpu()
            fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
            for key in fsdp_state_dict.keys():
                fsdp_state_dict[key] = fsdp_state_dict[key].cuda()

        blank_local_model.load_state_dict(fsdp_state_dict, strict=True)
        local_params = list(blank_local_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)

    @skip_if_lt_x_gpu(2)
    @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS)
    @parametrize("double_nest", [True])
    def test_state_dict_skip_module(self, state_dict_type, double_nest):
        torch.cuda.set_device(self.rank)

        def _create_module(wrap_fsdp=True):
            LINEAR_SKIP = "linear_skip"
            ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
            with ctx:
                module = SkipModel(double_nest=double_nest)
                # Full name of linear_skip param tensors in SkipModel, as would
                # be stored in checkpoint.
                linear_skip_tensor_names = [
                    k for k in dict(module.named_parameters()).keys()
                    if LINEAR_SKIP in k
                ]
                # skip SkipModule
                linear_skip = getattr(module, LINEAR_SKIP)
                delattr(module, LINEAR_SKIP)
                # Wrap FSDP
                fsdp = wrap(module)
                # reattach
                setattr(module, LINEAR_SKIP, linear_skip)
                return fsdp, linear_skip_tensor_names

        fsdp, linear_skip_tensor_names = _create_module()
        # Run a forward pass
        inp = torch.randn((1, 10), device=torch.cuda.current_device())
        loss = fsdp(inp)
        loss.sum().backward()

        with FSDP.state_dict_type(fsdp, STATE_DICT_MAPPING[state_dict_type]):
            state_dict = fsdp.state_dict()
        if self.rank == 0 and state_dict_type != "local_state_dict":
            sd_keys = list(state_dict.keys())
            expected = list(SkipModel(double_nest=False).state_dict().keys())
            self.assertEqual(sorted(sd_keys), sorted(expected))
            # TODO: parameters in linear_skip_tensor_names should not be
            # handled by FSDP.state_dict(). Have a check once this is
            # implemented in FSDP.state_dict().

        # Check that it can be loaded into FSDP.
        new_fsdp, _ = _create_module()
        _zero_model(new_fsdp)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertNotEqual(p1, p2)
        with FSDP.state_dict_type(new_fsdp, STATE_DICT_MAPPING[state_dict_type]):
            if state_dict_type != "local_state_dict":
                # FlatParameter has not supported deepcopy yet.
                state_dict = deepcopy(state_dict)
            new_fsdp.load_state_dict(state_dict, strict=True)
        for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()):
            self.assertEqual(p1, p2)

        # Test that the checkpoint can be loaded into a local model.
        local, _ = _create_module(wrap_fsdp=False)
        for param in local.parameters():
            with torch.no_grad():
                param.zero_()

        with fsdp.summon_full_params(fsdp):
            for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                self.assertNotEqual(p1, p2)

        if state_dict_type == "local_state_dict":
            return
        state_dict = _gather_state_dict(state_dict)
        with fsdp.summon_full_params(fsdp):
            if self.rank == 0:
                local.load_state_dict(state_dict, strict=True)
                for (p1, p2) in zip(fsdp.parameters(), local.parameters()):
                    self.assertEqual(p1, p2)

    @skip_if_lt_x_gpu(2)
    def test_wrong_state_dict_config(self):
        model = FSDP(Model(wrap_fsdp=True).cuda())
        with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):
            with model.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
            ):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("prefix", [True, False])
    @parametrize("ignore_inner", [True, False])
    def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(
            wrap_fsdp=True, register_buffers=True, ignore_inner=ignore_inner
        ).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored this test also ensures
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd1 = fsdp_model.state_dict(prefix=prefix_str)
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(
                tensor.data_ptr(),
                sd1[prefixed_tensor_name].data_ptr(),
                f"{prefixed_tensor_name}",
            )
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffers still have the same data pointers
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
            **ignored_tensor_to_tensor_name,
            **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(
                sd1[prefixed_tensor_name].data_ptr(),
                sd2[prefixed_tensor_name].data_ptr(),
            )

    @skip_if_lt_x_gpu(2)
    def test_state_dict_type(self):
        module = SkipModel(double_nest=True)
        with enable_wrap(wrapper_cls=FSDP):
            fsdp = wrap(module)
        with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
            pass
        for module in FSDP.fsdp_modules(fsdp):
            self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
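# A condensed, editorial sketch of the rank0_only + offload_to_cpu
# checkpointing flow that TestFSDPStateDict exercises above. Meant to run
# under e.g. `torchrun --nproc_per_node=2` with one GPU per rank; the model
# and the checkpoint path are stand-ins.
import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType

def _rank0_offload_checkpoint_sketch():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    model = FSDP(torch.nn.Linear(8, 8).cuda())
    cfg = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        state_dict = model.state_dict()  # full dict on rank 0 (CPU), empty elsewhere
    if rank == 0:
        torch.save(state_dict, "/tmp/ckpt.pt")  # hypothetical path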
class TestGradAcc(FSDPTest):
    """Tests ``FullyShardedDataParallel``'s gradient accumulation via both its
    ``no_sync()`` context manager and without the context manager."""

    def _test_grad_acc(
        self,
        batch_dim: int,
        configs: List[_GradAccConfig],
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation by comparing a run that trains sequentially
        through some batches while accumulating gradients with a run that
        trains on the concatenation of those batches in a single iteration.

        The last iteration always synchronizes gradients regardless of what is
        specified by the last element of ``configs``.

        Arguments:
            batch_dim (int): Batch dimension in the input tensor to be passed
                into the model for the forward pass.
            configs (List[_GradAccConfig]): :class:`list` of configurations
                specifying how gradients are accumulated; for example, a list
                corresponding to [(False, 2), (True, 2), (False, 2)] indicates
                to accumulate over 2 + 2 + 2 = 6 total iterations, where the
                first two do not use ``no_sync()``, the middle two do use
                ``no_sync()``, and the final two again do not use
                ``no_sync()``.
            cpu_offload (CPUOffload): Configures CPU offloading.
            backward_prefetch (Optional[BackwardPrefetch]): Specifies at which
                point to prefetch the next layer's full parameters during the
                backward pass, if at all.
        """
        # Gradient accumulation outside `no_sync()` is not currently compatible
        # with CPU offloading
        if cpu_offload.offload_params and \
                any(not config.use_no_sync for config in configs):
            return
        old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32
        try:
            # Disable TF32 to prevent floating point drift
            torch.backends.cuda.matmul.allow_tf32 = False

            # Initialize the FSDP model and optimizer
            group = dist.distributed_c10d._get_default_group()
            fsdp_model: FSDP = self._get_wrapped_model(
                group,
                cuda_first=False,
                add_bn=False,  # disable BN since the test uses varying batch sizes
                config={
                    "cpu_offload": cpu_offload,
                    "backward_prefetch": backward_prefetch,
                },
            )
            fsdp_model.eval()  # disable dropout
            device = torch.device("cuda")
            optim = torch.optim.SGD(
                fsdp_model.parameters(),
                lr=0.01,
                momentum=0.9,
            )

            # Generate the sequence of batches, each containing the same data
            # but permuted
            def permute_tensor(x: torch.Tensor):
                return x.view(-1)[torch.randperm(x.numel())].view_as(x)

            batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device)
            batches: List[Tuple[torch.Tensor, ...]] = [batch]
            num_iters_to_acc = sum(config.num_iters for config in configs)
            for _ in range(num_iters_to_acc - 1):
                batches.append(tuple(permute_tensor(t) for t in batch))
            for (batch1, batch2) in itertools.combinations(batches, r=2):
                for t1, t2 in zip(batch1, batch2):
                    assert not torch.all(t1 == t2), \
                        "Check the test to make sure that batches are distinct"

            # Concatenate the batches along the given batch dimension
            concat_batch: Tuple[torch.Tensor, ...] = tuple(
                torch.cat(ts, dim=batch_dim) for ts in zip(*batches)
            )

            # Establish reference gradients using the concatenated batch
            fsdp_model.zero_grad()
            output = fsdp_model(*concat_batch)
            ref_loss = fsdp_model.module.get_loss(concat_batch, output)
            ref_loss.backward()
            ref_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compute and accumulate the gradients
            fsdp_model.zero_grad()
            losses = []
            batch_idx = 0
            for config in configs:
                sync_context = fsdp_model.no_sync() if config.use_no_sync \
                    else contextlib.suppress()
                with sync_context:
                    for _ in range(config.num_iters):
                        if batch_idx == num_iters_to_acc - 1:
                            break  # always sync on the last iteration
                        batch = batches[batch_idx]
                        batch_idx += 1
                        output = fsdp_model(*batch)
                        loss = fsdp_model.module.get_loss(batch, output)
                        loss.backward()
                        losses.append(loss)
            output = fsdp_model(*batches[-1])
            loss = fsdp_model.module.get_loss(batches[-1], output)
            loss.backward()
            losses.append(loss)
            acc_loss = sum(losses)
            acc_grads = [
                p.grad.detach().clone() for p in fsdp_model.parameters()
            ]

            # Compare the losses and gradients
            torch.testing.assert_close(ref_loss, acc_loss)
            self.assertEqual(len(ref_grads), len(acc_grads))
            for ref_grad, acc_grad in zip(ref_grads, acc_grads):
                self.assertEqual(ref_grad.device, acc_grad.device)
                self.assertEqual(ref_grad.size(), acc_grad.size())
                self.assertEqual(ref_grad.dtype, acc_grad.dtype)
                torch.testing.assert_close(ref_grad, acc_grad)

            # Check that the optimizer step does not error
            optim.step()
        finally:
            torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "configs",
        [
            _GradAccConfigs([
                _GradAccConfig(use_no_sync=True, num_iters=3),
                _GradAccConfig(use_no_sync=False, num_iters=3),
                _GradAccConfig(use_no_sync=True, num_iters=3),
            ]),
            _GradAccConfigs([
                _GradAccConfig(use_no_sync=False, num_iters=3),
                _GradAccConfig(use_no_sync=True, num_iters=3),
                _GradAccConfig(use_no_sync=False, num_iters=3),
            ]),
        ],
    )
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None],
    )
    def test_grad_acc(
        self,
        configs: _GradAccConfigs,
        cpu_offload: CPUOffload,
        backward_prefetch: Optional[BackwardPrefetch],
    ):
        """
        Tests gradient accumulation.

        This exercises gradient accumulation inside and outside the
        ``no_sync()`` context manager, in particular by interleaving the two.
        It tests both interleaving starting with (and ending with, resp.)
        inside versus outside ``no_sync()`` to ensure that initial conditions
        (and final conditions, resp.) do not affect the correctness.

        This test also checks for compatibility with the CPU offload and
        backward prefetch options.

        NOTE: Gradient accumulation without using the ``no_sync()`` context
        manager is not currently compatible with CPU offloading, so those tests
        are vacuous.
        """
        self._test_grad_acc(
            batch_dim=1,
            configs=configs.configs,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
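# The invariant the test above relies on, shown without FSDP as an editorial
# aside: for a sum-style loss, accumulating .backward() over micro-batches
# yields the same gradients as a single backward over the concatenated batch.
# CPU-only, with made-up shapes.
import torch

def _grad_acc_invariant_sketch():
    torch.manual_seed(0)
    lin = torch.nn.Linear(4, 1)
    micro_batches = [torch.randn(2, 4) for _ in range(3)]

    # Reference: one backward over the concatenation
    lin.zero_grad()
    lin(torch.cat(micro_batches, dim=0)).sum().backward()
    ref_grad = lin.weight.grad.clone()

    # Accumulation: backward per micro-batch; grads sum in place
    lin.zero_grad()
    for mb in micro_batches:
        lin(mb).sum().backward()

    torch.testing.assert_close(lin.weight.grad, ref_grad)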
class TestSummonFullParams(FSDPTest):
    @property
    def world_size(self):
        return 2

    def get_model_param_count(self, m):
        return sum([p.numel() for p in m.parameters()])

    # padding ensures that all shards have the same size with the least amount
    # of padding
    def get_expected_sharded_size(self, global_size):
        return int(math.ceil(global_size / self.world_size))

    @skip_if_lt_x_gpu(2)
    @parametrize("writeback", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("mixed_precision", [True, False])
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(
        self, writeback, cpu_offload, mixed_precision, modify_outer
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        return _run_test_summon_full_param_writeback(
            self,
            writeback,
            modify_outer,
            cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("mixed_precision", [True, False])
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size, self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParamHandle.flatten_params(parameters, requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]
            # shards are padded but the full_param tensor is not
            self.assertTrue(
                torch.equal(my_shard[0:my_slice.numel()].cpu(), my_slice.cpu())
            )

    @skip_if_lt_x_gpu(2)
    @parametrize("recurse", [True, False])
    @parametrize("summon_outer", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_summon_full_param_recursive(self, recurse, summon_outer, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # outer is summoned if _summon_full_param is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # inner is summoned if _summon_full_param is called with recursion or on
        # the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon.summon_full_params(model_to_summon, recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_forward(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Parameter(torch.zeros(5))

            def forward(self, fsdp_module):
                with fsdp_module.summon_full_params(fsdp_module):
                    pass

        model = FSDP(MyModule()).cuda(self.rank)
        with self.assertRaisesRegex(ValueError, "current state is TrainingState_.FORWARD"):
            model(model)

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_backward(self):
        model = FSDP(nn.Linear(2, 1)).cuda(self.rank)

        output = model(torch.ones(2).cuda(self.rank))

        def bad_backwards_hook(tensor):
            with model.summon_full_params(model):
                pass
            return None

        self.assertTrue(output.requires_grad)
        output.register_hook(bad_backwards_hook)

        with self.assertRaisesRegex(ValueError, "current state is TrainingState_.BACKWARD_PRE"):
            output.backward()

    @skip_if_lt_x_gpu(2)
    @parametrize("mixed_precision", [True, False])
    def test_summon_full_params_respects_reshard_after_forward(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))

        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # similarly summon_full_params should have the same behavior
        with model.summon_full_params(model):
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    def test_summon_single_param(self):
        model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

        p = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, p.numel())

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model.summon_full_params(model, writeback=True):
            self.assertEqual(1, p.numel())
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        # most ranks hold no data and wrote to padding so only rank zero will
        # observe the above write
        if self.rank == 0:
            self.assertEqual(0, p[0])
        else:
            self.assertEqual(self.rank + 2, p[0])

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(
            DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
            cpu_offload=offload,
        )
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = (
            [p.clone() for p in model.parameters()]
            if rank0_only and self.rank != 0
            else list(local_model.parameters())
        )

        writeback = not rank0_only
        with model.summon_full_params(
            model,
            recurse=True,
            rank0_only=rank0_only,
            writeback=writeback,
            offload_to_cpu=offload_to_cpu,
        ):
            if writeback:
                with torch.no_grad():
                    for p in model.parameters():
                        p.add_(1)
                    for p in params_to_compare:
                        p.add_(1)
            # Below sleep causes failures without stream synchronization in
            # summon_full_params fix.
            torch.cuda._sleep(1000000)
            # FSDP param deepcopy() of params has issues
            fsdp_params = [p.clone() for p in model.parameters()]
            self.assertEqual(fsdp_params, params_to_compare)

        # CPU offload is enabled for main API, so we should point back to CPU
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("cpu"))

    @skip_if_lt_x_gpu(2)
    def test_summon_from_non_fsdp(self):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2, fsdp_3):
                super().__init__()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2
                self.fsdp_3 = fsdp_3

        model_fsdp = FSDPContainer(
            FSDP(DeterministicModel(wrap_fsdp=True)),
            FSDP(DeterministicModel(wrap_fsdp=True)),
            DeterministicModel(wrap_fsdp=False),
        )
        model_no_fsdp = FSDPContainer(
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
        )

        params_to_compare = list(model_no_fsdp.parameters())
        with FSDP.summon_full_params(model_fsdp):
            fsdp_params = [p.clone() for p in model_fsdp.parameters()]

        self.assertEqual(params_to_compare, fsdp_params)

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_reshard_outside_forward_backward_iteration(
        self, rank0_only, offload_to_cpu, mixed_precision
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First lets validate our assumption about resharding
        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # now lets repeat it with summon done in between
        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_params_are_unflattenned(self, rank0_only, offload_to_cpu, mixed_precision):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(deepcopy(model), mixed_precision=mixed_precision).cuda(self.rank)

        def _get_flat_param():
            return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

        flattened_param = _get_flat_param()
        self.assertEqual(layer_shape[0] * layer_shape[1] // 2, flattened_param.numel())

        with fsdp_model.summon_full_params(
            fsdp_model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            if self.rank == 0 or not rank0_only:
                self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
                expected_device = (
                    torch.device("cpu") if offload_to_cpu
                    else torch.device("cuda", torch.cuda.current_device())
                )
                self.assertTrue(expected_device == fsdp_model.weight.device)
            else:
                # Nonzero rank with rank0_only maintains original params.
                flat_within_ctx = _get_flat_param()
                self.assertEqual(flat_within_ctx, flattened_param)
                self.assertEqual(
                    flat_within_ctx.device, torch.device(torch.cuda.current_device())
                )

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    @parametrize("rank0_only", [True, False])
    @parametrize("offload_to_cpu", [True, False])
    @parametrize("mixed_precision", [True, False])
    def test_params_count_and_value(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: bool,
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        dev = (
            torch.device("cpu") if offload_to_cpu
            else torch.device("cuda", torch.cuda.current_device())
        )
        params_to_compare = (
            [p.to(dev) for p in model.module.parameters()]
            if not rank0_only or self.rank == 0
            else list(p.clone() for p in fsdp_model.parameters())
        )
        with FSDP.summon_full_params(
            fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
        ):
            for p1, p2 in itertools.zip_longest(
                fsdp_model.parameters(), params_to_compare
            ):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device())
        )

    @skip_if_lt_x_gpu(2)
    def test_raises_rank0_with_writeback(self):
        """Tests that ``summon_full_params()`` with both ``rank0_only=True``
        and ``writeback=True`` raises an error."""
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
        )
        with self.assertRaisesRegex(ValueError, "is not supported"):
            with FSDP.summon_full_params(
                nested_wrapped_module, rank0_only=True, writeback=True
            ):
                pass

    @skip_if_lt_x_gpu(2)
    @parametrize("prefix", ["", "test_prefix"])
    @parametrize("recurse", [False, True])
    def test_named_parameters_buffers(self, prefix: str, recurse: bool):
        """Tests that ``named_parameters()`` and ``named_buffers()`` for a
        top-level FSDP-wrapped model match their behavior for the equivalent
        non-wrapped model."""
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        model.register_buffer("buffer", torch.ones(1))
        # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
        # if called on a non-FSDP root module
        fsdp_model = FSDP(
            NestedWrappedModule.init(
                self.process_group,
                FSDPInitMode.NO_FSDP,
                CUDAInitMode.CUDA_BEFORE,
                deterministic=True,
            ),
            self.process_group,
        )
        fsdp_model.register_buffer("buffer", torch.ones(1))
        with FSDP.summon_full_params(fsdp_model):
            for call in ["named_parameters", "named_buffers"]:
                for (n1, p1), (n2, p2) in itertools.zip_longest(
                    getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                    getattr(model, call)(prefix=prefix, recurse=recurse),
                ):
                    self.assertEqual(n1, n2)
                    self.assertEqual(p1, p2)
def _test_identical_outputs(
    self,
    model_init_fn,
    *args,
    ref_ddp_fn=None,
    num_steps=2,
    fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
    lr=0.01,
    cpu_offload=CPUOffload(),
    backward_prefetch=None,
    sharding_strategy=None,
    save_model=True,
    clip_norm=0.3,
    norm_type=None,
    **kwargs,
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    # DDP training
    ref_loss = self._train_for_several_steps(
        model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
    )
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    try:
        model = model_init_fn(
            group=group,
            wrap_fsdp=True,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
        )
    except Exception as e:
        raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")
    cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
    model = FullyShardedDataParallel(
        model,
        cpu_offload=cpu_offload,
        backward_prefetch=backward_prefetch,
        sharding_strategy=sharding_strategy,
    )
    # Call model.cuda() after init FSDP if specified.
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        model = model.cuda()

    # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
    # expect FSDP code to raise the error that we check below, in the case of
    # offloaded params.
    if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        for p in model.parameters():
            # Should be on CPU regardless of whether the param is sharded.
            self.assertEqual(
                p.device, torch.device("cpu"), f"Mismatch, cpu offload is {cpu_offload}"
            )

    only_check_err = (
        fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
    )
    ctx = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if only_check_err
        else suppress()
    )
    with ctx:
        # FSDP training
        shard_loss = self._train_for_several_steps(
            model,
            num_steps,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            save_model=save_model,
        )
    # We only check for errors in the case we have the following setup:
    #   model = FSDP(model, cpu_offload=True)
    #   model = model.cuda()
    # so skip the rest of this logic.
    if only_check_err:
        return
    # If CPU offload, next call will change model params to GPU. Sanity
    # check that params are on CPU before.
    if cpu_offload.offload_params:
        device_set = {p.device for p in model.parameters()}
        self.assertEqual(
            {torch.device("cpu")}, device_set, f"Got device set {device_set}"
        )
    shard_full_params = get_full_params(model)

    if cpu_offload.offload_params:
        shard_loss = shard_loss.cuda()
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )
def _test_fsdp_parity(
    self,
    model_class: Type[FSDPTestModel],
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    ref_init_fn: Optional[Callable] = None,
    num_iters: int = 2,
    save_model: bool = True,
    cpu_offload: CPUOffload = CPUOffload(),
    backward_prefetch: Optional[BackwardPrefetch] = None,
    forward_prefetch: bool = False,
    sharding_strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    enable_sharded_grad_scaler: bool = False,
    use_pure_fp16: bool = False,
    norm_type: Optional[Union[float, int]] = None,
    init_kwargs: Optional[Dict[str, Any]] = None,
    **fsdp_kwargs,
):
    """
    Tests FSDP training against a reference, which defaults to DDP but may be
    customized with ``ref_init_fn``.

    Args:
        model_class (Type[FSDPTestModel]): A model class that inherits from
            ``FSDPTestModel``, which defines the expected interface.
        fsdp_init_mode (FSDPInitMode): The mode to initialize the FSDP-wrapped
            model. This should not be ``NO_FSDP``.
        ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
            non-wrapped model to construct the reference model, where this
            wrapper should provide data parallel semantics. If ``None``, then
            the callable defaults to the DDP constructor.
    """
    assert fsdp_init_mode != FSDPInitMode.NO_FSDP, \
        "Expects an FSDP init mode that wraps with FSDP"
    if init_kwargs is None:
        init_kwargs = {}
    lr = 1e-2
    rank = self.process_group.rank()
    # Establish reference behavior with DDP
    model = model_class.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
        **init_kwargs,
    )
    if ref_init_fn is None:
        ref_model = DDP(model, device_ids=[rank], output_device=rank)
    else:
        ref_model = ref_init_fn(model)
    if use_pure_fp16:
        ref_model = ref_model.half()
    ref_loss = self._train_for_several_steps(
        ref_model,
        num_iters,
        autocast=mixed_precision is not None,
        lr=lr,
        fsdp_cpu_offload=cpu_offload,
        mixed_precision=mixed_precision,
        norm_type=norm_type,
        enable_sharded_grad_scaler=enable_sharded_grad_scaler,
        use_pure_fp16=use_pure_fp16,
    )
    ddp_params = list(ref_model.parameters())
    # Check against FSDP behavior
    fsdp_kwargs.update(
        {
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        }
    )
    try:
        fsdp_model = model_class.init(
            self.process_group,
            fsdp_init_mode,
            cuda_init_mode,
            fsdp_kwargs,
            deterministic=True,
            **init_kwargs,
        )
    except Exception as e:
        raise ValueError(f"Initializing {model_class} raised error {str(e)}")
    if not isinstance(fsdp_model, FSDP):
        # Enforce that we wrap with top-level FSDP since we are comparing
        # assuming a data parallel reference and some test models may not
        # do so in their `init()` method
        fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
    if use_pure_fp16:
        # Change the model parameter dtype after FSDP initialization
        fsdp_model = fsdp_model.half()
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
        fsdp_model = fsdp_model.cuda()
    offload_params = cpu_offload is not None and cpu_offload.offload_params
    # Offloading parameters with `CUDA_AFTER` should raise an error during
    # lazy initialization due to the parameter devices not being CPU;
    # otherwise, all parameter devices should be CPU
    expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
    expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
    if expects_cpu_device:
        cpu_device = torch.device("cpu")
        for param in fsdp_model.parameters():
            self.assertEqual(param.device, cpu_device)
    context = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if expects_device_error
        else suppress()
    )
    with context:
        fsdp_loss = self._train_for_several_steps(
            fsdp_model,
            num_iters,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            save_model=save_model,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
    # No need to check for parameter and loss parity if expecting an error
    if expects_device_error:
        return
    # Check parameter devices are CPU if offloading to CPU before calling
    # `get_full_params()`, which will cast the parameters to FP32
    if offload_params:
        for param in fsdp_model.parameters():
            self.assertEqual(param.device, cpu_device)
        fsdp_loss = fsdp_loss.cuda()
    fsdp_unsharded_params = get_full_params(fsdp_model)
    torch.testing.assert_allclose(ref_loss, fsdp_loss)
    # Do not check for parameter parity if using mixed precision since (1)
    # the DDP parameters are in FP16 (from `half()`) while the FSDP
    # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
    # the optimizer in FP16 while FSDP runs it in FP32
    if mixed_precision is None:
        self.assertEqual(
            ddp_params,
            fsdp_unsharded_params,
            exact_device=True,
            msg="FSDP did not match DDP",
        )
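# An editorial sketch of the offload + CUDA_AFTER failure mode checked above,
# using the imports already present in this file: with offload_params=True,
# FSDP expects parameters to be on CPU at lazy initialization, so moving the
# wrapped model to GPU beforehand should trip the "Expected param to be on
# CPU" assertion during training. Assumes an initialized process group and one
# GPU; the module is a stand-in.
def _offload_cuda_after_sketch():
    model = FSDP(
        torch.nn.Linear(4, 4), cpu_offload=CPUOffload(offload_params=True)
    )
    model = model.cuda()  # moves the sharded flat params off the CPU
    try:
        model(torch.randn(2, 4, device="cuda")).sum().backward()
    except AssertionError as e:
        print("raised as expected:", e)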
mp_configs = [
    default_mp,
    mp_only_reduce,
    mp_only_param_and_buf,
    mp_no_mixed_precision,
]
if nccl_supports_bf16:
    mp_diff_buffer_and_reduce = MixedPrecision(
        param_dtype=torch.float16,
        buffer_dtype=torch.bfloat16,
        reduce_dtype=torch.float32,
    )
    mp_configs.extend([mp_diff_buffer_and_reduce])

# Buffer original dtype, which can differ from model params.
_BUFFER_ORIG_DTYPE = torch.float64

params = "mp_config,cpu_offload,backward_prefetch,full_precision_param_dtype"
cpu_offload_config = [
    CPUOffload(offload_params=True),
    CPUOffload(offload_params=False),
]
backward_prefetch_config = [
    BackwardPrefetch.BACKWARD_PRE,
    BackwardPrefetch.BACKWARD_POST,
]
full_precision_param_dtype_config = [torch.float32, torch.float64]
configs = list(
    product(
        mp_configs,
        cpu_offload_config,
        backward_prefetch_config,
        full_precision_param_dtype_config,
    )
)

test_name_mapping = {
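# For context, an editorial sketch of how a config product like the one above
# is typically consumed: `params` names the fields and each tuple in `configs`
# becomes one parametrized case, with a helper turning a config into a
# readable subtest name. `name_fn` here is illustrative, not the repo's actual
# `test_name_mapping` helper.
def name_fn(mp_config, cpu_offload, backward_prefetch, dtype):
    offload = "offload" if cpu_offload.offload_params else "no_offload"
    prefetch = str(backward_prefetch).split(".")[-1].lower()
    return f"{offload}_{prefetch}_{str(dtype).split('.')[-1]}"

# e.g. name_fn(default_mp, CPUOffload(offload_params=True),
#              BackwardPrefetch.BACKWARD_PRE, torch.float32)
# -> "offload_backward_pre_float32"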
class TestNoSync(FSDPTest): """Tests ``FullyShardedDataParallel``'s gradient accumulation via its ``no_sync()`` context manager.""" def _test_no_sync( self, batch_dim: int, num_iters_to_acc: int, cpu_offload: CPUOffload, backward_prefetch: Optional[BackwardPrefetch], ): """ Tests ``no_sync()`` by comparing a run that trains sequentially through some batches while accumulating gradients with a run that trains on the concatenation of those batches in a single iteration. The number of batches, i.e. the number of iterations for which to accumulate gradients, is given by ``num_iters_to_acc``. Arguments: batch_dim (int): Batch dimension in the input tensor to be passed into the model for the forward pass. num_iters_to_acc (int): Number of iterations for which to accumulate gradients; all but the last iteration are run using the ``no_sync()`` context manager so that gradients are not synchronized until the final iteration. cpu_offload (CPUOffload): Configures CPU offloading. backward_prefetch (Optional[BackwardPrefetch]): Specifies at which point to prefetch the next layer's full parameters during the backward pass, if at all. """ old_allow_tf32 = torch.backends.cuda.matmul.allow_tf32 try: # Disable TF32 to prevent floating point drift torch.backends.cuda.matmul.allow_tf32 = False # Initialize the FSDP model and optimizer group = dist.distributed_c10d._get_default_group() fsdp_model: FSDP = self._get_wrapped_model( group, cuda_first=False, add_bn=False, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, ) # disable BN since the test uses varying batch sizes fsdp_model.eval() # disable dropout device = torch.device("cuda") optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.01, momentum=0.9) # Generate the sequence of batches, each containing the same data but # permuted def permute_tensor(x: torch.Tensor): return x.view(-1)[torch.randperm(x.numel())].view_as(x) batch: Tuple[torch.Tensor, ...] = fsdp_model.module.get_input(device) batches: List[Tuple[torch.Tensor, ...]] = [batch] for _ in range(num_iters_to_acc - 1): batches.append(tuple(permute_tensor(t) for t in batch)) for (batch1, batch2) in itertools.combinations(batches, r=2): for t1, t2 in zip(batch1, batch2): assert not torch.all(t1 == t2) # Concatenate the batches along the given batch dimension concat_batch: Tuple[torch.Tensor, ...] 
= tuple( torch.cat(ts, dim=batch_dim) for ts in zip(*batches)) # Establish reference gradients using the concatenated batch fsdp_model.zero_grad() output = fsdp_model(*concat_batch) ref_loss = fsdp_model.module.get_loss(concat_batch, output) ref_loss.backward() ref_grads = [ p.grad.detach().clone() for p in fsdp_model.parameters() ] # Compute the gradients by accumulating via `no_sync()` fsdp_model.zero_grad() losses = [] with fsdp_model.no_sync(): for batch in batches[: -1]: # accumulate for all but the last batch output = fsdp_model(*batch) loss = fsdp_model.module.get_loss(batch, output) loss.backward() losses.append(loss) output = fsdp_model(*batches[-1]) loss = fsdp_model.module.get_loss(batches[-1], output) loss.backward() losses.append(loss) acc_loss = sum(losses) acc_grads = [ p.grad.detach().clone() for p in fsdp_model.parameters() ] # Compare the losses and gradients torch.testing.assert_allclose(ref_loss, acc_loss) assert len(ref_grads) == len(acc_grads) for ref_grad, acc_grad in zip(ref_grads, acc_grads): assert ref_grad.device == acc_grad.device assert ref_grad.size() == acc_grad.size() assert ref_grad.dtype == acc_grad.dtype torch.testing.assert_allclose(ref_grad, acc_grad) # Check that the optimizer step does not error optim.step() finally: torch.backends.cuda.matmul.allow_tf32 = old_allow_tf32 @skip_if_lt_x_gpu(2) @parametrize( "num_iters_to_acc", [2, 4], ) @parametrize( "cpu_offload", [CPUOffload(offload_params=False), CPUOffload(offload_params=True)], ) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) def test_no_sync( self, num_iters_to_acc: int, cpu_offload: CPUOffload, backward_prefetch: Optional[BackwardPrefetch], ): """Tests the ``no_sync()`` context manager.""" assert num_iters_to_acc >= 2, \ "Accumulate for at least 2 iterations to be nontrivial" self._test_no_sync( batch_dim=1, num_iters_to_acc=num_iters_to_acc, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, )
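# A minimal usage sketch (not part of the test above) of the `no_sync()`
# accumulation pattern being verified: gradients are accumulated locally for
# all but the last microbatch, and gradient communication happens only on
# the final backward. Assumes `fsdp_model`, `microbatches`, and `optim`
# already exist; the model is assumed to map a tensor to a tensor.
def _no_sync_accumulation_sketch(fsdp_model: FSDP, microbatches, optim):
    with fsdp_model.no_sync():
        for batch in microbatches[:-1]:  # no gradient communication here
            fsdp_model(batch).sum().backward()
    fsdp_model(microbatches[-1]).sum().backward()  # syncs accumulated grads
    optim.step()
    optim.zero_grad()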
run_tests, ) if not dist.is_available(): print("Distributed not available, skipping tests", file=sys.stderr) sys.exit(0) if TEST_WITH_DEV_DBG_ASAN: print( "Skip dev-asan as torch + multiprocessing spawn have known issues", file=sys.stderr, ) sys.exit(0) params = "cpu_offload,backward_prefetch,forward_prefetch,sharding_strategy" cpu_offload_config = [CPUOffload(offload_params=True), CPUOffload(offload_params=False)] backward_prefetch_config = [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None] forward_prefetch_config = ["forward_prefetch", "no_forward_prefetch"] sharding_strategy_config = [ShardingStrategy.SHARD_GRAD_OP, None, ShardingStrategy.NO_SHARD] configs = list(itertools.product(cpu_offload_config, backward_prefetch_config, forward_prefetch_config, sharding_strategy_config)) test_name_mapping = { str(CPUOffload(offload_params=True)): "offload_true", str(CPUOffload(offload_params=False)): "offload_false", str(BackwardPrefetch.BACKWARD_PRE): "backward_prefetch_pre", str(BackwardPrefetch.BACKWARD_POST): "backward_prefetch_post", "forward_prefetch": "forward_prefetch", "no_forward_prefetch": "no_forward_prefetch", str(ShardingStrategy.SHARD_GRAD_OP): "shard_grad_op",
class TestSummonFullParams(FSDPTest):
    @property
    def world_size(self):
        return 2

    def get_model_param_count(self, m):
        return sum(p.numel() for p in m.parameters())

    # padding ensures that all shards have the same size with the least amount of padding
    def get_expected_sharded_size(self, global_size):
        return int(math.ceil(global_size / self.world_size))

    @skip_if_lt_x_gpu(2)
    @parametrize("writeback", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("modify_outer", [True, False])
    def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        # set the value
        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        p = outer_param if modify_outer else inner_param

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=writeback):
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        if writeback:
            self.assertEqual(p.cpu()[0], 0)
        else:
            self.assertEqual(p.cpu()[0], self.rank + 2)

    @skip_if_lt_x_gpu(2)
    def test_summon_full_param_shard_value(self):
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank))
        self.assertEqual(expected_shard_size, self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model._summon_full_params():
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParameter(parameters, requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0 : my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))

    @skip_if_lt_x_gpu(2)
    @parametrize("recurse", [True, False])
    @parametrize("summon_outer", [True, False])
    def test_summon_full_param_recursive(self, recurse, summon_outer):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # outer is summoned if _summon_full_param is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # inner is summoned if _summon_full_param is called with recursion or on the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon._summon_full_params(recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_forward(self):
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.a = nn.Parameter(torch.zeros(5))

            def forward(self, fsdp_module):
                with fsdp_module._summon_full_params():
                    pass

        model = FSDP(MyModule()).cuda(self.rank)
        with self.assertRaisesRegex(
            ValueError, "current state is TrainingState_.FORWARD"
        ):
            model(model)

    @skip_if_lt_x_gpu(2)
    def test_cannot_summon_full_params_from_backward(self):
        model = FSDP(nn.Linear(2, 1)).cuda(self.rank)

        output = model(torch.ones(2).cuda(self.rank))

        def bad_backwards_hook(tensor):
            with model._summon_full_params():
                pass
            return None

        self.assertTrue(output.requires_grad)
        output.register_hook(bad_backwards_hook)

        with self.assertRaisesRegex(
            ValueError, "current state is TrainingState_.BACKWARD_PRE"
        ):
            output.backward()

    @skip_if_lt_x_gpu(2)
    def test_summon_full_params_respects_reshard_after_forward(self):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))

        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # similarly _summon_full_params should have the same behavior
        with model._summon_full_params():
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    def test_summon_single_param(self):
        model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

        p = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, p.numel())

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=True):
            self.assertEqual(1, p.numel())
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        # most ranks hold no data and wrote to padding so only rank zero will observe the above write
        if self.rank == 0:
            self.assertEqual(0, p[0])
        else:
            self.assertEqual(self.rank + 2, p[0])

    @skip_if_lt_x_gpu(2)
    def test_reshard_outside_forward_backward_iteration(self):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 1, bias=False)
            )
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding
        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # now let's repeat it with summon done in between
        output = model(torch.zeros(5).cuda(self.rank))
        with model._summon_full_params():
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        output.backward()
        with model._summon_full_params():
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    @skip_if_lt_x_gpu(2)
    def test_params_are_unflattened(self):
        model = FSDP(nn.Linear(self.world_size, 1, bias=False)).cuda(self.rank)

        flattened_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, flattened_param.numel())

        with model._summon_full_params():
            a = model.weight.flatten().detach()
            b = flattened_param.detach()
            self.assertTrue(torch.equal(a, b))

    @skip_if_lt_x_gpu(2)
    def test_params_count_and_value(self):
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            )
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        with fsdp_model._summon_full_params():
            for p1, p2 in itertools.zip_longest(
                fsdp_model.parameters(), model.module.parameters()
            ):
                self.assertEqual(p1, p2)
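# A minimal sketch of the read-only summon pattern the tests above exercise
# (a sketch, assuming an already-constructed FSDP `model`): inside the
# context, `model.parameters()` yields the full, unflattened parameters;
# with `writeback=False`, any modification is discarded on exit, and each
# rank afterwards again holds only its padded local shard.
def _read_full_params_sketch(model: FSDP):
    with model._summon_full_params(writeback=False):
        return [p.detach().clone() for p in model.parameters()]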
class TestFSDPStateDict(FSDPTest): @property def world_size(self): return 2 def _broadcast_state_dict(self, state_dict): olist = [state_dict if self.rank == 0 else None] dist.broadcast_object_list(olist) return olist[0] def _get_simple_nested_model(self, *fsdp_args, **fsdp_kwargs): model = FSDP( nn.Sequential( FSDP( nn.Linear(10, 10, bias=False).cuda(), *fsdp_args, **fsdp_kwargs), nn.Linear(10, 10, bias=False).cuda(), ), *fsdp_args, **fsdp_kwargs, ) return model def _get_simple_model(self, *fsdp_args, **fsdp_kwargs): model = FSDP( nn.Linear(10, 10, bias=False).cuda(), *fsdp_args, **fsdp_kwargs) return model def _get_full_state_dict_mgr(self, model, state_dict_rank0_and_offload): return FSDP.state_dict_type( model, StateDictType.FULL_STATE_DICT, FullStateDictConfig( rank0_only=state_dict_rank0_and_offload, offload_to_cpu=state_dict_rank0_and_offload, )) def _validate_state_dict_contents(self, fsdp_state_dict, state_dict_rank0_and_offload, ignore_keys=None): if state_dict_rank0_and_offload: if self.rank == 0: self.assertNotEqual(fsdp_state_dict, {}) for key, tensor in fsdp_state_dict.items(): if ignore_keys and key in ignore_keys: continue self.assertEqual( tensor.device, torch.device("cpu"), f"{key} is unexpectedly on device {tensor.device}") else: self.assertEqual(fsdp_state_dict, {}) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], ) @parametrize("fp16", [True, False]) @parametrize("state_dict_rank0_and_offload", [True, False]) def test_basic_save_and_load_state_dict(self, cpu_offload, fp16, state_dict_rank0_and_offload): """ Tests that we can save a state_dict and load it into a blank model with various configs such as fp16 and cpu offload and parameters match as expected. """ for model_call in [ partial(self._get_simple_nested_model, cpu_offload=cpu_offload), partial(self._get_simple_model, cpu_offload=cpu_offload), ]: model = model_call() full_state_dict_mgr = self._get_full_state_dict_mgr( model, state_dict_rank0_and_offload) with full_state_dict_mgr: fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16) self._validate_state_dict_contents(fsdp_state_dict, state_dict_rank0_and_offload) if fp16: # Verify fp16 is the type for tensor in fsdp_state_dict.values(): self.assertEqual(tensor.dtype, torch.float16) model_new = model_call() if not cpu_offload.offload_params: model_new = model_new.cuda() if fp16: model_new.half() # zero the model to ensure parameters are different. _zero_model(model_new) with FullyShardedDataParallel.summon_full_params(model): with FullyShardedDataParallel.summon_full_params(model_new): params = list(model.parameters()) params_new = list(model_new.parameters()) self.assertNotEqual(params, params_new) # Verify parameters are the same in the new model. if state_dict_rank0_and_offload: # Broadcast the state dict and move it back to GPU in # preparation for loading. 
                fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict)
                for key in fsdp_state_dict.keys():
                    fsdp_state_dict[key] = fsdp_state_dict[key].cuda()
            model_new.load_state_dict(fsdp_state_dict)
            with FullyShardedDataParallel.summon_full_params(model_new):
                with FullyShardedDataParallel.summon_full_params(model):
                    params = list(model.parameters())
                    params_new = list(model_new.parameters())
                    self.assertEqual(params, params_new)
            if fp16:
                for tensor in model_new.parameters():
                    self.assertEqual(tensor.dtype, torch.float16)

    @skip_if_lt_x_gpu(2)
    @parametrize("mixed_precision", [True, False])
    @parametrize("state_dict_rank0_and_offload", [True, False])
    def test_save_and_load_after_forward_state_dict(
            self, mixed_precision, state_dict_rank0_and_offload):
        """
        Test that saving after some training results in params being updated
        as expected.
        """
        torch.cuda.set_device(self.rank)
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = self._get_simple_nested_model(mixed_precision=mixed_precision)
        optim = torch.optim.SGD(model.parameters(), lr=0.1)
        initial_params = _get_full_detached_param(model)
        for _ in range(6):
            inp = torch.randn(1, 10, device=torch.cuda.current_device())
            output = model(inp)
            loss = output.sum()
            expected_dtype = torch.float32 if mixed_precision is None else torch.float16
            self.assertEqual(expected_dtype, loss.dtype)
            loss.backward()
            optim.step()
        trained_params = _get_full_detached_param(model)
        # Ensure some training occurred
        self.assertNotEqual(initial_params, trained_params)
        # Save a copy of the state_dict
        fsd_mgr = self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload)
        with fsd_mgr:
            state_dict = {k: v.clone() for k, v in model.state_dict().items()}
        self._validate_state_dict_contents(state_dict, state_dict_rank0_and_offload)
        _zero_model(model)

        # Ensure checkpointed params have the full param dtype
        for tensor in state_dict.values():
            self.assertEqual(tensor.dtype, torch.float32)

        # Load state_dict into zeroed model
        if state_dict_rank0_and_offload:
            # Broadcast the state dict and move it back to GPU in
            # preparation for loading.
            state_dict = self._broadcast_state_dict(state_dict)
            for key in state_dict.keys():
                state_dict[key] = state_dict[key].cuda()
        model.load_state_dict(state_dict)
        loaded_params = _get_full_detached_param(model)
        self.assertEqual(loaded_params, trained_params)

    def _initialize_model(self, wrap_fsdp: bool, wrap_ddp: bool = True):
        # keep everything deterministic for input data
        torch.manual_seed(0)
        model = Model(wrap_fsdp).cuda()
        if wrap_fsdp:
            model = FSDP(model)
        elif wrap_ddp:
            model = DistributedDataParallel(model, device_ids=[self.rank])
        return model

    @staticmethod
    def _state_dict(model: Module, state_dict_type: str):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict type for {state_dict_type}")
        with FSDP.state_dict_type(model, enum_val):
            return model.state_dict()

    @staticmethod
    def _load_state_dict(model: Module, state_dict_type: str, state_dict: Dict[str, Any]):
        try:
            enum_val = STATE_DICT_MAPPING[state_dict_type]
        except KeyError:
            raise ValueError(f"No state_dict for {state_dict_type}")
        with FSDP.state_dict_type(model, enum_val):
            return model.load_state_dict(state_dict)

    def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""):
        # TODO: Move this test to common_fsdp.
model = self._initialize_model(wrap_fsdp) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if wrap_fsdp: blank_model = FSDP(Model(True).cuda()) _zero_model(blank_model) state_dict = self._state_dict(model, state_dict_type) self._load_state_dict(blank_model, state_dict_type, state_dict) return get_full_params(blank_model) else: return list(model.parameters()) @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_state_dict_save_load_flow(self, state_dict_type): fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type) ddp_params = self._dist_train(wrap_fsdp=False) self.assertEqual(ddp_params, fsdp_params) @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_fsdp_state_dict_keys(self, state_dict_type): state_dict = self._state_dict(self._initialize_model(True), state_dict_type) if state_dict_type == "local_state_dict": self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys()) elif state_dict_type == "state_dict": # Keys should match local model. local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False) local_keys = local_model.state_dict().keys() self.assertEqual(state_dict.keys(), local_keys) else: raise NotImplementedError(f"No test for {state_dict_type}!") @skip_if_lt_x_gpu(2) @parametrize("state_dict_rank0_and_offload", [True, False]) def test_state_dict_load_into_local_module(self, state_dict_rank0_and_offload): """ Tests that FSDP's state_dict can be loaded into a local model. """ model = self._initialize_model(wrap_fsdp=True) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() with FullyShardedDataParallel.summon_full_params(model): fsdp_params = deepcopy(list(model.parameters())) # get FSDP state_dict. Note that by default we return full_state_dict. sd_mgr = self._get_full_state_dict_mgr(model, state_dict_rank0_and_offload) with sd_mgr: fsdp_state_dict = model.state_dict() self._validate_state_dict_contents(fsdp_state_dict, state_dict_rank0_and_offload) # Create zeroed local model blank_local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False) for param in blank_local_model.parameters(): with torch.no_grad(): param.zero_() # Load fsdp's full state dict into the local and verify params are as # expected. if state_dict_rank0_and_offload: # Broadcast + CUDA state_dict fsdp_state_dict = self._broadcast_state_dict(fsdp_state_dict) for key in fsdp_state_dict.keys(): fsdp_state_dict[key] = fsdp_state_dict[key].cuda() blank_local_model.load_state_dict(fsdp_state_dict) local_params = list(blank_local_model.parameters()) for fsdp_param, local_param in zip(fsdp_params, local_params): self.assertEqual(fsdp_param, local_param) @skip_if_lt_x_gpu(2) @parametrize("double_nest", [True]) def test_state_dict_skip_module(self, double_nest): torch.cuda.set_device(self.rank) def _create_module(wrap_fsdp=True): LINEAR_SKIP = "linear_skip" ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress() with ctx: module = SkipModel(double_nest=double_nest) # Full name of linear_skip param tensors in SkipModel, as would be # stored in checkpoint. 
linear_skip_tensor_names = [ k for k in dict(module.named_parameters()).keys() if LINEAR_SKIP in k ] # skip SkipModule linear_skip = getattr(module, LINEAR_SKIP) delattr(module, LINEAR_SKIP) # Wrap FSDP fsdp = wrap(module) # reattach setattr(module, LINEAR_SKIP, linear_skip) return fsdp, linear_skip_tensor_names fsdp, linear_skip_tensor_names = _create_module() # Run a forward pass inp = torch.randn((1, 10), device=torch.cuda.current_device()) loss = fsdp(inp) loss.sum().backward() state_dict = fsdp.state_dict() if self.rank == 0: sd_keys = list(state_dict.keys()) expected = list(SkipModel(double_nest=False).state_dict().keys()) self.assertEqual(sorted(sd_keys), sorted(expected)) # TODO: parameters in linear_skip_tensor_names should not be handled # by FSDP.state_dict(). Have a check once this is implemented in # FSDP.state_dict(). # Check that it can be loaded into FSDP. new_fsdp, _ = _create_module() _zero_model(new_fsdp) for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()): self.assertNotEqual(p1, p2) new_fsdp.load_state_dict(deepcopy(state_dict)) for (p1, p2) in zip(fsdp.parameters(), new_fsdp.parameters()): self.assertEqual(p1, p2) # Test that the checkpoint can be loaded into a local model. local, _ = _create_module(wrap_fsdp=False) for param in local.parameters(): with torch.no_grad(): param.zero_() with fsdp.summon_full_params(fsdp): for (p1, p2) in zip(fsdp.parameters(), local.parameters()): self.assertNotEqual(p1, p2) local.load_state_dict(deepcopy(state_dict)) with fsdp.summon_full_params(fsdp): for (p1, p2) in zip(fsdp.parameters(), local.parameters()): self.assertEqual(p1, p2) @skip_if_lt_x_gpu(2) def test_wrong_state_dict_config(self): model = FSDP(Model(wrap_fsdp=True).cuda()) with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"): with model.state_dict_type(model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()): pass @skip_if_lt_x_gpu(2) def test_state_dict_with_ignored_modules(self): # Initialize an FSDP-wrapped model with an ignored module model = Model(wrap_fsdp=True).cuda() ignored_modules = [model.outer] ignored_param_to_param_name = { model.outer.bias: "outer.bias", model.outer.weight: "outer.weight", } fsdp_model = FSDP(model, ignored_modules=ignored_modules) with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT): sd = fsdp_model.state_dict() with FSDP.summon_full_params(fsdp_model): fsdp_params = deepcopy(list(fsdp_model.parameters())) # Check that the ignored parameters are not cloned for param, param_name in ignored_param_to_param_name.items(): self.assertTrue(param_name in sd) self.assertEqual(param.data_ptr(), sd[param_name].data_ptr()) # Check that the state dict can be loaded into a non-wrapped version of # the model nonwrapped_model = Model(wrap_fsdp=False).cuda() for param in nonwrapped_model.parameters(): with torch.no_grad(): param.zero_() nonwrapped_model.load_state_dict(sd) local_params = list(nonwrapped_model.parameters()) for fsdp_param, local_param in zip(fsdp_params, local_params): self.assertEqual(fsdp_param, local_param)
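# A minimal sketch of the rank0-only full-state-dict round trip covered by
# the tests above, using the same `state_dict_type` / `FullStateDictConfig`
# APIs (a sketch, not a canonical checkpointing recipe; assumes an
# initialized process group and a CUDA device for this rank).
def _full_state_dict_roundtrip_sketch(model: FSDP, new_model: FSDP, rank: int):
    cfg = FullStateDictConfig(rank0_only=True, offload_to_cpu=True)
    with FSDP.state_dict_type(model, StateDictType.FULL_STATE_DICT, cfg):
        sd = model.state_dict()  # non-empty only on rank 0; tensors on CPU
    # Broadcast from rank 0 so every rank can load, then move back to GPU
    olist = [sd if rank == 0 else None]
    dist.broadcast_object_list(olist)
    sd = {k: v.cuda() for k, v in olist[0].items()}
    new_model.load_state_dict(sd)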
class TestFSDPStateDict(FSDPTest): @property def world_size(self): return 2 def _get_simple_nested_model(self, *fsdp_args, **fsdp_kwargs): model = FSDP( nn.Sequential( FSDP(nn.Linear(10, 10, bias=False), *fsdp_args, **fsdp_kwargs), nn.Linear(10, 10, bias=False), ), *fsdp_args, **fsdp_kwargs, ) return model def _get_simple_model(self, *fsdp_args, **fsdp_kwargs): model = FSDP(nn.Linear(10, 10, bias=False), *fsdp_args, **fsdp_kwargs) return model @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)], ) @parametrize("fp16", [True, False]) def test_basic_save_and_load_state_dict(self, cpu_offload, fp16): """ Tests that we can save a state_dict and load it into a blank model with various configs such as fp16 and cpu offload and parameters match as expected. """ for model_call in [ partial(self._get_simple_nested_model, cpu_offload=cpu_offload), partial(self._get_simple_model, cpu_offload=cpu_offload), ]: model = model_call() fsdp_state_dict = _get_state_dict(model, cpu_offload.offload_params, fp16) if fp16: # Verify fp16 is the type for tensor in fsdp_state_dict.values(): self.assertEqual(tensor.dtype, torch.float16) model_new = model_call() if not cpu_offload.offload_params: model_new = model_new.cuda() if fp16: model_new.half() # zero the model to ensure parameters are different. _zero_model(model_new) with model.summon_full_params(), model_new.summon_full_params(): params = list(model.parameters()) params_new = list(model_new.parameters()) self.assertNotEqual(params, params_new) # Verify parameters are the same in the new model. model_new.load_state_dict(fsdp_state_dict) with model_new.summon_full_params(): with model.summon_full_params(): params = list(model.parameters()) params_new = list(model_new.parameters()) self.assertEqual(params, params_new) if fp16: for tensor in model_new.parameters(): self.assertEqual(tensor.dtype, torch.float16) @skip_if_lt_x_gpu(2) def test_save_and_load_after_forward_state_dict(self): """ Test that saving after some training results in params being updated as expected. 
""" torch.cuda.set_device(self.rank) model = self._get_wrapped_model(group=torch.distributed.distributed_c10d._get_default_group()) optim = torch.optim.SGD(model.parameters(), lr=0.1) initial_params = _get_full_detached_param(model) for _ in range(6): inp = model.module.get_input(torch.device("cuda")) output = model(*inp) loss = model.module.get_loss(inp, output).cuda() model.module.run_backward(loss) optim.step() trained_params = _get_full_detached_param(model) # Ensure some training occured self.assertNotEqual(initial_params, trained_params) # Save a copy of the state_dict state_dict = {k: v.clone() for k, v in model.state_dict().items()} _zero_model(model) # Load state_dict into zeroed model model.load_state_dict(state_dict) loaded_params = _get_full_detached_param(model) self.assertEqual(loaded_params, trained_params) def _initialize_model(self, wrap_fsdp: bool, wrap_ddp: bool = True): # keep everything deterministic for input data torch.manual_seed(0) model = Model(wrap_fsdp).cuda() if wrap_fsdp: model = FSDP(model) elif wrap_ddp: model = DistributedDataParallel(model, device_ids=[self.rank]) return model @staticmethod def _state_dict(model: Module, state_dict_type: str): try: enum_val = STATE_DICT_MAPPING[state_dict_type] except KeyError: raise ValueError(f"No state_dict type for {state_dict_type}") with FSDP.state_dict_type(model, enum_val): return model.state_dict() @staticmethod def _load_state_dict( model: Module, state_dict_type: str, state_dict: Dict[str, Any] ): try: enum_val = STATE_DICT_MAPPING[state_dict_type] except KeyError: raise ValueError(f"No state_dict for {state_dict_type}") with FSDP.state_dict_type(model, enum_val): return model.load_state_dict(state_dict) def _dist_train(self, wrap_fsdp: bool, state_dict_type: str = ""): # TODO: Move this test to common_fsdp. model = self._initialize_model(wrap_fsdp) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() if wrap_fsdp: blank_model = FSDP(Model(True).cuda()) _zero_model(blank_model) state_dict = self._state_dict(model, state_dict_type) self._load_state_dict(blank_model, state_dict_type, state_dict) return get_full_params(blank_model) else: return list(model.parameters()) @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_state_dict_save_load_flow(self, state_dict_type): fsdp_params = self._dist_train(wrap_fsdp=True, state_dict_type=state_dict_type) ddp_params = self._dist_train(wrap_fsdp=False) self.assertEqual(ddp_params, fsdp_params) @skip_if_lt_x_gpu(2) @parametrize("state_dict_type", _SUPPORTED_STATE_DICT_IMPLS) def test_fsdp_state_dict_keys(self, state_dict_type): state_dict = self._state_dict(self._initialize_model(True), state_dict_type) if state_dict_type == "local_state_dict": self.assertEqual(set(["flat_param", "inner.flat_param"]), state_dict.keys()) elif state_dict_type == "state_dict": # Keys should match local model. local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False) local_keys = local_model.state_dict().keys() self.assertEqual(state_dict.keys(), local_keys) else: raise NotImplementedError(f"No test for {state_dict_type}!") @skip_if_lt_x_gpu(2) def test_state_dict_load_into_local_module(self): """ Tests that FSDP's state_dict can be loaded into a local model. 
""" model = self._initialize_model(wrap_fsdp=True) optim = SGD(model.parameters(), lr=0.1) in_data = torch.rand(64, 4, requires_grad=True, device=torch.device("cuda")) for _ in range(3): out = model(in_data) out.sum().backward() optim.step() optim.zero_grad() with model.summon_full_params(): fsdp_params = deepcopy(list(model.parameters())) # get FSDP state_dict. Note that by default we return state_dict. fsdp_state_dict = model.state_dict() # Create zeroed local model blank_local_model = self._initialize_model(wrap_fsdp=False, wrap_ddp=False) for param in blank_local_model.parameters(): with torch.no_grad(): param.zero_() # Load fsdp's full state dict into the local and verify params are as # expected. blank_local_model.load_state_dict(fsdp_state_dict) local_params = list(blank_local_model.parameters()) for fsdp_param, local_param in zip(fsdp_params, local_params): self.assertEqual(fsdp_param, local_param)
class TestParityWithDDP(FSDPTest): """ Compare losses and parameter values after several updates when using PyTorch DDP vs. FullyShardedDataParallel. """ def _get_init_modes_for_test(self, cpu_offload): modes = [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE] # Note that FSDPInitMode.CUDA_NEVER works currently only with CPU # offload as we explicitly bring the param back to CUDA device. In # general, it will not work since we try to all_gather p.data which is # on CPU but NCCL only supports GPU. if cpu_offload.offload_params: modes.append(FSDPInitMode.CUDA_NEVER) return modes @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) def test_nested_wrapped_model(self, cpu_offload, backward_prefetch): init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): self._test_identical_outputs( NestedWrappedModule, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, ) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) @parametrize("clip_norm_type", [2.0, None]) def test_nested_all_wrapped_model(self, cpu_offload, backward_prefetch, clip_norm_type): init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): model_fn = functools.partial(NestedWrappedModule, wrap_everything=True) self._test_identical_outputs( model_fn, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, norm_type=clip_norm_type, ) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) @parametrize("clip_norm_type", [2.0, None]) def test_transformer_parameterized(self, cpu_offload, backward_prefetch, clip_norm_type): init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): self._test_identical_outputs( TransformerWithSharedParams, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, norm_type=clip_norm_type, ) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) def test_delayed_optim_step(self, cpu_offload, backward_prefetch): # We use a model with a long CUDA delay right before the optimizer step. # This tests our streams logic, and that we don't start the allgather # until after the optimization step completes. 
init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_after_loss_ms=250) self._test_identical_outputs( model_fn, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, ) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) def test_delayed_reduce_scatter(self, cpu_offload, backward_prefetch): # We insert a delay in the torch.distributed._reduce_scatter_base op, so that # the post_backward_stream takes much longer than the backward pass. # This tests that we properly block at the end of the backward pass for # the reductions to finish. init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): model_fn = functools.partial(NestedWrappedModuleWithDelay, delay_before_reduction_ms=250) self._test_identical_outputs( model_fn, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, ) def _dummy_ddp_fn(self, model): return DummyDDP(model) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) @parametrize("clip_norm_type", [2.0, None]) def test_mixture_of_experts(self, cpu_offload, backward_prefetch, clip_norm_type): init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): self._test_identical_outputs( MixtureOfExperts, # MixtureOfExperts implements custom reduce logic, so the reference # behavior should use that logic instead of PyTorch DDP. ref_ddp_fn=self._dummy_ddp_fn, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, norm_type=clip_norm_type, ) @skip_if_lt_x_gpu(2) @parametrize( "cpu_offload", [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]) @parametrize( "backward_prefetch", [BackwardPrefetch.BACKWARD_PRE, BackwardPrefetch.BACKWARD_POST, None]) def test_mixture_of_experts_with_delay_before_free(self, cpu_offload, backward_prefetch): init_modes = self._get_init_modes_for_test(cpu_offload) for fsdp_init_mode in init_modes: with self.subTest(fsdp_init_mode=fsdp_init_mode): model_fn = functools.partial(MixtureOfExperts, delay_before_free_ms=250) self._test_identical_outputs( model_fn, ref_ddp_fn=self._dummy_ddp_fn, fsdp_init_mode=fsdp_init_mode, cpu_offload=cpu_offload, backward_prefetch=backward_prefetch, )
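# A hedged sketch of the delay technique the delayed-optim, delayed
# reduce-scatter, and delay-before-free tests above rely on:
# `torch.cuda._sleep` spins the current CUDA stream for a fixed cycle count
# so that, absent correct stream synchronization, dependent work would
# overlap incorrectly and read stale data. This module is hypothetical; the
# suite uses NestedWrappedModuleWithDelay and MixtureOfExperts with
# built-in delays.
class _DelayedLinear(nn.Module):
    def __init__(self, delay_cycles: int = 1000000):
        super().__init__()
        self.lin = nn.Linear(2, 2)
        self.delay_cycles = delay_cycles

    def forward(self, x):
        torch.cuda._sleep(self.delay_cycles)  # stall the current CUDA stream
        return self.lin(x)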