def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    """Checks that ``named_parameters()`` and ``named_buffers()`` on a
    top-level FSDP-wrapped model yield the same names and values as on the
    equivalent non-wrapped model.

    Args:
        prefix (str): Forwarded as ``prefix`` to both methods.
        recurse (bool): Forwarded as ``recurse`` to both methods.
    """

    def _init_unwrapped():
        # Both models are built with identical, deterministic arguments
        return NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )

    model = _init_unwrapped()
    model.register_buffer("buffer", torch.ones(1))
    # `named_parameters()` and `named_buffers` will contain FSDP prefixes
    # if called on a non-FSDP root module
    fsdp_model = FSDP(_init_unwrapped(), self.process_group)
    fsdp_model.register_buffer("buffer", torch.ones(1))
    with FSDP.summon_full_params(fsdp_model):
        for method_name in ("named_parameters", "named_buffers"):
            fsdp_items = getattr(fsdp_model, method_name)(
                prefix=prefix, recurse=recurse
            )
            ref_items = getattr(model, method_name)(prefix=prefix, recurse=recurse)
            # `zip_longest` pads the shorter iterable with `None`, so a
            # length mismatch fails at the tuple unpacking below
            for (fsdp_name, fsdp_value), (ref_name, ref_value) in (
                itertools.zip_longest(fsdp_items, ref_items)
            ):
                self.assertEqual(fsdp_name, ref_name)
                self.assertEqual(fsdp_value, ref_value)
def test_diff_ignored_modules_across_ranks(
    self, pass_ignored_modules_to_root: bool
):
    """
    Tests ignoring different modules across ranks.

    Args:
        pass_ignored_modules_to_root (bool): If ``False``, the root FSDP
            instance receives an empty ignored-module list (even for modules
            already ignored in child FSDP instances); if ``True``, it
            receives every ``IgnoredModule`` found in the model, which is a
            superset of the children's ignored modules.
    """

    def _collect_ignored(root):
        return [m for m in root.modules() if isinstance(m, IgnoredModule)]

    # To exercise different `FlatParameter` enumerations across ranks,
    # we wrap `layer3` with FSDP, where `layer3` is registered as a module
    # after `layer1`, which has the variable number of ignored modules
    model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
    model.layer1 = FSDP(
        model.layer1, ignored_modules=_collect_ignored(model.layer1)
    )
    model.layer3 = FSDP(model.layer3)
    if pass_ignored_modules_to_root:
        root_ignored_modules = _collect_ignored(model)
    else:
        root_ignored_modules = []
    wrapped_model = FSDP(model, ignored_modules=root_ignored_modules)
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
def _get_wrapped_model(
    self,
    group,
    cuda_first=False,
    config=None,
    **model_kwargs,
) -> FullyShardedDataParallel:
    """Builds a ``TransformerWithSharedParams`` wrapped in FSDP.

    Args:
        group: Process group passed to both the transformer and FSDP.
        cuda_first: If true, the transformer is moved to CUDA before it is
            wrapped; otherwise the wrapped model is moved afterwards.
        config: Optional dict of keyword arguments for the FSDP constructor.
        **model_kwargs: Forwarded to ``TransformerWithSharedParams``.

    Returns:
        The FSDP-wrapped model. It is not moved to CUDA when
        ``config["cpu_offload"].offload_params`` is truthy.
    """
    fsdp_config = {} if config is None else config
    offloads_params = (
        "cpu_offload" in fsdp_config and fsdp_config["cpu_offload"].offload_params
    )
    move_to_cuda = not offloads_params
    transformer = TransformerWithSharedParams(group, **model_kwargs)
    if cuda_first:
        if move_to_cuda:
            transformer = transformer.cuda()
        return FullyShardedDataParallel(transformer, group, **fsdp_config)
    wrapped = FullyShardedDataParallel(transformer, group, **fsdp_config)
    if move_to_cuda:
        wrapped = wrapped.cuda()
    return wrapped
def test_fsdp_same_model_across_ranks(self):
    """
    Checks that constructing FSDP with ``sync_module_states=True`` turns a
    model whose states differ across ranks into one whose states match,
    both for a CUDA module and for a CPU module with ``device_id`` given.
    """

    class MyModel(nn.Module):
        def __init__(self, rank):
            super().__init__()
            # Seed via rank to make model different across ranks
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)
            self.lin = nn.Linear(10, 10, bias=False)
            self.register_buffer("buffer", torch.ones(1) * rank)

    def _check_sync(module, **fsdp_kwargs):
        # Before wrapping, the module states are expected to differ
        _assert_module_states(
            module,
            process_group=self.process_group,
            assert_fn=self.assertNotEqual,
        )
        # Passing sync_module_states into FSDP makes model the same during init.
        fsdp = FSDP(module, sync_module_states=True, **fsdp_kwargs)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(
                fsdp,
                process_group=self.process_group,
                assert_fn=self.assertEqual,
            )

    _check_sync(MyModel(self.rank).cuda())
    # sync_module_states also works with CPU module with device_id passed in
    _check_sync(MyModel(self.rank), device_id=torch.cuda.current_device())
def test_input_type(self, input_cls):
    """Runs a few training steps through FSDP with the model input packed in
    a list or a dict (selected by ``input_cls``), only single GPU."""

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            # Unpack the tensor from whichever container was passed in
            if isinstance(input, list):
                tensor = input[0]
            else:
                assert isinstance(input, dict), input
                tensor = input["in"]
            return self.layer(tensor)

    model = FSDP(Model()).cuda()
    optim = SGD(model.parameters(), lr=0.1)
    for _ in range(5):
        sample = torch.rand(64, 4).cuda()
        sample.requires_grad = True
        if input_cls is list:
            batch = [sample]
        else:
            self.assertTrue(input_cls is dict)
            batch = {"in": sample}
        model(batch).sum().backward()
        optim.step()
        optim.zero_grad()
def test_fsdp_cpu_init_stays_on_cpu(self):
    """Tests that passing a CPU module to FSDP keeps the wrapped module on
    CPU after FSDP initialization, albeit after logging a warning, and that
    forward/backward with CPU input works once the model is moved to CUDA."""
    torch.cuda.set_device(self.rank)
    with self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex="Module is put on CPU"
    ):
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        fsdp_model = FSDP(nested_wrapped_module, self.process_group)
    # Every parameter must report the same device, and it must be the CPU
    param_devices = {param.device for param in fsdp_model.parameters()}
    self.assertEqual(1, len(param_devices))
    self.assertEqual(torch.device("cpu"), param_devices.pop())
    fsdp_model = fsdp_model.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to appropriate
    # CUDA device.
    cpu_input = fsdp_model.module.get_input(device=torch.device("cpu"))
    fsdp_model(*cpu_input).sum().backward()
def init(
    group: dist.ProcessGroup,
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    fsdp_kwargs: Optional[Dict[str, Any]] = None,
    deterministic: bool = False,
):
    """
    Initializes a :class:`NestedWrappedModule` instance, but unlike
    :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
    wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
    policy.

    Args:
        group (dist.ProcessGroup): Forwarded to the parent ``init``.
        fsdp_init_mode (FSDPInitMode): ``NO_FSDP`` returns the non-wrapped
            model; ``RECURSIVE`` returns the FSDP-wrapped model. Any other
            value falls through and returns ``None``.
        cuda_init_mode (CUDAInitMode): Forwarded to the parent ``init``;
            additionally, ``CUDA_AFTER`` moves the FSDP-wrapped model to
            CUDA after wrapping.
        fsdp_kwargs (Optional[Dict[str, Any]]): Extra keyword arguments for
            the FSDP constructor; ``None`` is treated as no extra arguments.
        deterministic (bool): Forwarded to the parent ``init``.
    """
    super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule)
    # The parent always builds the non-wrapped model; wrapping happens here
    model = super_.init(
        group=group,
        fsdp_init_mode=FSDPInitMode.NO_FSDP,
        cuda_init_mode=cuda_init_mode,
        fsdp_kwargs=fsdp_kwargs,
        deterministic=deterministic,
    )
    if fsdp_init_mode == FSDPInitMode.NO_FSDP:
        return model
    elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
        # `fsdp_kwargs` defaults to `None`, which cannot be unpacked with
        # `**`, so substitute an empty dict in that case
        extra_fsdp_kwargs = {} if fsdp_kwargs is None else fsdp_kwargs
        fsdp_model = FSDP(
            model, auto_wrap_policy=always_wrap_policy, **extra_fsdp_kwargs
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        return fsdp_model
def test_summon_full_param_shard_value(self):
    """Checks that this rank's local shard matches the corresponding chunk
    of the full parameter exposed inside ``_summon_full_params()``."""
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)
    model = FSDP(raw_model.cuda(self.rank))
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))
    # we're assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))
    # Clone the local shard before summoning, since the parameter is
    # re-read inside the context below
    my_shard = torch.clone(next(model.parameters()))
    with model._summon_full_params():
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        all_shards = next(model.parameters())
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]
        # shards are padded but the full_param tensor is not, so only the
        # leading `my_slice.numel()` elements of the shard are compared
        my_shard_unpadded = my_shard[0:my_slice.numel()]
        self.assertTrue(torch.equal(my_shard_unpadded.cpu(), my_slice.cpu()))
def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
    """Checks whether zeroing a flat parameter inside
    ``_summon_full_params()`` is reflected in it afterwards, depending on
    ``writeback``.

    Args:
        writeback: Forwarded to ``_summon_full_params``.
        cpu_offload: Not used by this body.
            NOTE(review): looks like it was meant to configure FSDP -- confirm.
        modify_outer: Selects the outer (if true) or the inner flat
            parameter as the one to modify.
    """
    inner_linear = FSDP(nn.Linear(5, 5, bias=False))
    outer_linear = nn.Linear(5, 3, bias=False)
    model = FSDP(nn.Sequential(inner_linear, outer_linear)).cuda(self.rank)

    if modify_outer:
        target = model.get_parameter("_fsdp_wrapped_module.flat_param")
    else:
        target = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )

    with torch.no_grad():
        # This sets the local shard value
        target[0] = self.rank + 2

    with model._summon_full_params(writeback=writeback), torch.no_grad():
        target.copy_(torch.zeros_like(target))

    expected_first = 0 if writeback else self.rank + 2
    self.assertEqual(target.cpu()[0], expected_first)
def test_fsdp_cpu_init_stays_on_cpu(self):
    """
    Checks that a module initialized on CPU is still on CPU after FSDP
    init, that a warning is raised while doing so, and that forward and
    backward work with CPU input once the model is moved to CUDA.
    """
    torch.cuda.set_device(self.rank)
    with self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex="Module is put on CPU"
    ):
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
        )
        fsdp = FSDP(mod)
    # Every parameter must report the same device, and it must be the CPU
    param_devices = {param.device for param in fsdp.parameters()}
    self.assertEqual(1, len(param_devices))
    self.assertEqual(torch.device("cpu"), param_devices.pop())
    fsdp = fsdp.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to appropriate
    # CUDA device.
    cpu_input = mod.get_input(device=torch.device("cpu"))
    fsdp(cpu_input[0]).sum().backward()
def test_summon_full_params_respects_reshard_after_forward(self):
    """Checks that entering and leaving ``_summon_full_params()`` after a
    forward pass leaves the padded full-parameter storage sizes of the outer
    and inner flat parameters unchanged."""
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)),
            nn.Linear(5, 3, bias=False),
        )
    ).cuda(self.rank)
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    def _assert_storage_sizes():
        # the root FSDP module keeps all params around, the inner one does not
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # trigger lazy init
    model(torch.zeros(5).cuda(self.rank))
    _assert_storage_sizes()

    # similarly _summon_full_params should have the same behavior
    with model._summon_full_params():
        pass
    _assert_storage_sizes()
def test_wrong_state_dict_config(self):
    """Checks that pairing ``FULL_STATE_DICT`` with a
    ``LocalStateDictConfig`` raises a ``RuntimeError``."""
    model = FSDP(Model(wrap_fsdp=True).cuda())
    expected_message = "Expected state_dict_config of type"
    with self.assertRaisesRegex(RuntimeError, expected_message), \
            model.state_dict_type(
                model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
            ):
        pass
def test_state_dict_with_ignored_modules(self):
    """Checks that the full state dict of an FSDP model with an ignored
    module holds the ignored parameters without cloning them and can be
    loaded into a non-wrapped model."""
    # Initialize an FSDP-wrapped model with an ignored module
    model = Model(wrap_fsdp=True).cuda()
    ignored_names_by_param = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    fsdp_model = FSDP(model, ignored_modules=[model.outer])
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))

    # Check that the ignored parameters are not cloned: the state dict entry
    # must share its data pointer with the original parameter
    for ignored_param, ignored_name in ignored_names_by_param.items():
        self.assertTrue(ignored_name in sd)
        self.assertEqual(ignored_param.data_ptr(), sd[ignored_name].data_ptr())

    # Check that the state dict can be loaded into a non-wrapped version of
    # the model; its parameters are zeroed first so the load is observable
    nonwrapped_model = Model(wrap_fsdp=False).cuda()
    with torch.no_grad():
        for local_param in nonwrapped_model.parameters():
            local_param.zero_()
    nonwrapped_model.load_state_dict(sd)
    for fsdp_param, local_param in zip(
        fsdp_params, list(nonwrapped_model.parameters())
    ):
        self.assertEqual(fsdp_param, local_param)
def test_ignored_modules_transformer(self):
    """Tests that ignored modules' parameters are not flattened for a
    transformer model with shared parameters."""

    def _init_transformer() -> nn.Module:
        return TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )

    # Initialize an FSDP-wrapped transformer model that has FSDP ignore
    # the `nn.Transformer` module's parameters
    model = _init_transformer()
    wrapped_model = FSDP(
        model,
        self.process_group,
        ignored_modules=[model.transformer],
    )
    # Check that the wrapped model's flattened parameter does not include
    # the ignored transformer module's parameters
    nonwrapped_model = _init_transformer()
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.transformer.parameters()
    )
    expected_flat_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        self.assertEqual(wrapped_model.params[0].numel(), expected_flat_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precision):
    """Compares the parameters seen inside ``summon_full_params()`` with
    those of an equivalent non-wrapped model, or, on nonzero ranks with
    ``rank0_only``, with clones of the FSDP model's own parameters.

    Args:
        rank0_only: Forwarded to ``summon_full_params``; ``writeback`` is
            set to its negation.
        offload_to_cpu: Selects the device the reference parameters are
            moved to.
            NOTE(review): not forwarded to ``summon_full_params`` here --
            confirm whether that is intended.
        mixed_precision: If truthy, a default ``MixedPrecision()`` config is
            used for the FSDP-wrapped model.
    """
    mp_config = MixedPrecision() if mixed_precision else None
    default_group = dist.distributed_c10d._get_default_group()
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=default_group,
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            mixed_precision=mp_config,
        ),
        mixed_precision=mp_config,
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )

    if offload_to_cpu:
        dev = torch.device("cpu")
    else:
        dev = torch.device("cuda", torch.cuda.current_device())

    if not rank0_only or self.rank == 0:
        params_to_compare = [p.to(dev) for p in model.module.parameters()]
    else:
        params_to_compare = [p.clone() for p in fsdp_model.parameters()]

    with fsdp_model.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        # `zip_longest` pads with `None`, so a count mismatch fails below
        for fsdp_param, ref_param in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(fsdp_param, ref_param)

    # CPU offload should restore the param device
    first_param = next(fsdp_model.parameters())
    cuda_device = torch.device("cuda", torch.cuda.current_device())
    self.assertTrue(first_param.device == cuda_device)
def test_summon_full_param_shard_value(self, mixed_precision):
    """Checks that this rank's local shard matches the corresponding chunk
    of the full parameters exposed inside ``summon_full_params()``.

    Args:
        mixed_precision: If truthy, a default ``MixedPrecision()`` config is
            passed to FSDP; otherwise ``None`` is passed.
    """
    mixed_precision = MixedPrecision() if mixed_precision else None
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)
    model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))
    # we're assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))
    # Clone the local shard before summoning, since the parameters are
    # re-read inside the context below
    my_shard = torch.clone(next(model.parameters()))
    with model.summon_full_params(model):
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        parameters = list(model.parameters())
        # Combine the summoned parameters into one `FlatParameter` so it can
        # be chunked in the same way as the sharded flat parameter
        all_shards = FlatParameter(parameters, requires_grad=False)
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]
        # shards are padded but the full_param tensor is not, so only the
        # leading `my_slice.numel()` elements of the shard are compared
        my_shard_unpadded = my_shard[0 : my_slice.numel()]
        self.assertTrue(torch.equal(my_shard_unpadded.cpu(), my_slice.cpu()))
def _test_mixed_precision_embedding_table(self, mp_config):
    """Basic test to ensure int inputs are not casted, which would break
    modules such as embedding tables.

    Args:
        mp_config: Mixed precision config passed to FSDP; its
            ``param_dtype`` (or ``torch.float32`` if falsy) is the expected
            dtype of the loss.
    """
    expected_dtype = mp_config.param_dtype or torch.float32
    # Route reduce-scatter through the validating wrapper, bound to the
    # original implementation and the config under test
    validating_reduce_scatter = partial(
        self._reduce_scatter_base_validate_mp,
        dist._reduce_scatter_base,
        mp_config,
    )
    with patch_reduce_scatter(validating_reduce_scatter, expected_dtype):
        # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
        # entire `TransformerWithSharedParams` with a single top-level FSDP
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {"mixed_precision": mp_config},
        )
        fsdp_model = FSDP(model, mixed_precision=mp_config)
        optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(6):
            batch = fsdp_model.module.get_input(torch.device("cuda"))
            # This would fail if we casted integer module inputs such as for
            # embedding tables.
            prediction = fsdp_model(*batch)
            loss = fsdp_model.module.get_loss(batch, prediction).cuda()
            self.assertEqual(loss.dtype, expected_dtype)
            fsdp_model.module.run_backward(loss)
            optim.step()
def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    """Compares clones of the parameters seen inside
    ``summon_full_params()`` with those of a non-wrapped
    ``DeterministicModel``, or, on nonzero ranks with ``rank0_only``, with
    clones of the FSDP model's own parameters taken beforehand.

    Args:
        rank0_only: Forwarded to ``summon_full_params``; ``writeback`` is
            set to its negation.
        offload_to_cpu: Forwarded to ``summon_full_params``.
    """
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    # NOTE(review): `dev` is computed but never read in this body -- confirm
    # whether the reference parameters were meant to be moved to it.
    if offload_to_cpu:
        dev = torch.device("cpu")
    else:
        dev = torch.device("cuda", torch.cuda.current_device())

    if rank0_only and self.rank != 0:
        params_to_compare = [p.clone() for p in model.parameters()]
    else:
        params_to_compare = list(local_model.parameters())

    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        # Below sleep causes failures without stream synchronization in
        # summon_full_params fix.
        torch.cuda._sleep(1000000)
        # FSDP param deepcopy() of params has issues
        summoned_params = [p.clone() for p in model.parameters()]

    self.assertEqual(summoned_params, params_to_compare)
def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_outer):
    """Checks whether zeroing a flat parameter inside
    ``summon_full_params()`` is reflected in it afterwards.

    Args:
        cls: Object providing ``rank``, ``world_size`` and ``assertEqual``.
        writeback: Forwarded to ``summon_full_params``.
        cpu_offload: Not used by this body.
            NOTE(review): looks like it was meant to configure FSDP -- confirm.
        modify_outer: Selects the outer (if true) or the inner flat
            parameter as the one to modify.
    """
    inner_linear = FSDP(nn.Linear(5, 5, bias=False))
    outer_linear = nn.Linear(5, 3, bias=False)
    model = FSDP(nn.Sequential(inner_linear, outer_linear)).cuda(cls.rank)

    if modify_outer:
        target = model.get_parameter("_fsdp_wrapped_module.flat_param")
    else:
        target = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )

    with torch.no_grad():
        # This sets the local shard value
        target[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback), torch.no_grad():
        target.copy_(torch.zeros_like(target))

    # When world_size = 1, FSDP does not shard and parameter is not set to
    # a local shard, so write is always reflected.
    write_is_visible = writeback or cls.world_size == 1
    expected_first = 0 if write_is_visible else cls.rank + 2
    cls.assertEqual(target.cpu()[0], expected_first)
def test_params_are_unflattenned(self, rank0_only, offload_to_cpu):
    """Checks that the original ``weight`` is visible with its shape and the
    expected device inside ``summon_full_params()``, and that nonzero ranks
    under ``rank0_only`` keep their flat parameter instead.

    Args:
        rank0_only: Forwarded to ``summon_full_params``; ``writeback`` is
            set to its negation.
        offload_to_cpu: Forwarded to ``summon_full_params``; selects the
            device expected for the summoned weight.
    """
    layer_shape = (10, 12)
    model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
    fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

    def _get_flat_param():
        return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

    flattened_param = _get_flat_param()
    # The expected shard size is half of the layer's element count
    self.assertEqual(
        layer_shape[0] * layer_shape[1] / 2, flattened_param.numel()
    )

    sees_full_params = self.rank == 0 or not rank0_only
    with fsdp_model.summon_full_params(
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        if sees_full_params:
            self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
            if offload_to_cpu:
                expected_device = torch.device("cpu")
            else:
                expected_device = torch.device(
                    "cuda", torch.cuda.current_device()
                )
            self.assertTrue(expected_device == fsdp_model.weight.device)
        else:
            # Nonzero rank with rank0_only maintains original params.
            flat_within_ctx = _get_flat_param()
            self.assertEqual(flat_within_ctx, flattened_param)
            self.assertEqual(
                flat_within_ctx.device,
                torch.device(torch.cuda.current_device()),
            )

    # CPU offload should restore the param device
    first_param = next(fsdp_model.parameters())
    cuda_device = torch.device("cuda", torch.cuda.current_device())
    self.assertTrue(first_param.device == cuda_device)
def test_summon_full_param_recursive(self, recurse, summon_outer):
    """Checks which flat parameters grow to their global size inside
    ``_summon_full_params()``, depending on the module it is called on and
    on ``recurse``.

    Args:
        recurse: Forwarded to ``_summon_full_params``.
        summon_outer: If true, summon on the outer FSDP module; otherwise on
            ``model[0]``, the inner FSDP module.
    """
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)),
            nn.Linear(5, 3, bias=False),
        )
    ).cuda(self.rank)

    global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
    global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))
    # Sharded sizes are rounded up to cover the whole parameter
    shard_inner_numel = math.ceil(global_inner_numel / self.world_size)
    shard_outer_numel = math.ceil(global_outer_numel / self.world_size)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    self.assertEqual(shard_outer_numel, outer_param.numel())
    self.assertEqual(shard_inner_numel, inner_param.numel())

    if summon_outer:
        model_to_summon = model
        # outer is summoned if _summon_full_param is called on the outer FSDP module
        expected_outer_numel = global_outer_numel
    else:
        model_to_summon = model[0]
        expected_outer_numel = shard_outer_numel

    # inner is summoned if _summon_full_param is called with recursion or on
    # the inner FSDP module
    if recurse or not summon_outer:
        expected_inner_numel = global_inner_numel
    else:
        expected_inner_numel = shard_inner_numel

    with model_to_summon._summon_full_params(recurse=recurse):
        self.assertEqual(expected_outer_numel, outer_param.numel())
        self.assertEqual(expected_inner_numel, inner_param.numel())
def test_mixed_precision_resnet(self):
    """
    End to end test to ensure mixed precision + auto_wrap works for ResNet
    model.
    """
    resnet_model = torchvision.models.resnet50().cuda()
    resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
        resnet_model,
        process_group=dist.distributed_c10d._get_default_group(),
    )
    # Number of batch norm modules in the model before wrapping
    n_bn = sum(isinstance(m, _BatchNorm) for m in resnet_model.modules())
    inp = torch.ones(1, 3, 1000, 1000, device='cuda')
    half_precision = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    fsdp = FSDP(
        resnet_model,
        auto_wrap_policy=size_based_auto_wrap_policy,
        mixed_precision=half_precision,
    )
    # Batchnorm units should be wrapped individually. Validate this by
    # ensuring there are equal no. of FSDP units that are BN as BN units
    # in original resnet model.
    fsdp_bn = sum(
        isinstance(fsdp_unit.module.module, _BatchNorm)
        for fsdp_unit in fsdp.fsdp_modules(fsdp)
    )
    self.assertEqual(fsdp_bn, n_bn)
    # Would throw type mismatch issue without mixed precision autowrapping.
    fsdp(inp).sum().backward()
def test_state_dict_type(self):
    """Checks that every FSDP module reports ``FULL_STATE_DICT`` again after
    a ``state_dict_type`` context set to ``LOCAL_STATE_DICT`` has exited."""
    skip_model = SkipModel(double_nest=True)
    with enable_wrap(wrapper_cls=FSDP):
        fsdp = wrap(skip_model)
    # Enter and immediately leave the context; only the state after the
    # exit is inspected
    with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
        pass
    for fsdp_module in FSDP.fsdp_modules(fsdp):
        self.assertEqual(
            fsdp_module._state_dict_type, StateDictType.FULL_STATE_DICT
        )
def test_reshard_outside_forward_backward_iteration(
    self, rank0_only, offload_to_cpu, mixed_precision
):
    """Checks the padded full-parameter storage sizes of the outer and inner
    flat parameters after forward and backward, with and without
    ``summon_full_params()`` being entered in between.

    Args:
        rank0_only: Forwarded to ``summon_full_params``; ``writeback`` is
            set to its negation.
        offload_to_cpu: Forwarded to ``summon_full_params``.
        mixed_precision: If truthy, a default ``MixedPrecision()`` config is
            used for both FSDP instances.
    """
    mp_config = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mp_config),
            nn.Linear(5, 1, bias=False),
        ),
        mixed_precision=mp_config,
    ).cuda(self.rank)
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    def _assert_storage_sizes(expected_outer_size):
        # The inner storage is expected to be empty at every checkpoint
        self.assertEqual(
            expected_outer_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

    def _summon_and_exit():
        with model.summon_full_params(
            model,
            rank0_only=rank0_only,
            writeback=not rank0_only,
            offload_to_cpu=offload_to_cpu,
        ):
            pass

    # First lets validate our assumption about resharding
    output = model(torch.zeros(5).cuda(self.rank))
    # the root FSDP module keeps all params around
    _assert_storage_sizes(outer_full_param_size)
    output.backward()
    # we reshard everything after backward() finishes
    _assert_storage_sizes(0)

    # now lets repeat it with summon done in between
    output = model(torch.zeros(5).cuda(self.rank))
    _assert_storage_sizes(outer_full_param_size)
    _summon_and_exit()
    _assert_storage_sizes(outer_full_param_size)
    output.backward()
    _summon_and_exit()
    _assert_storage_sizes(0)
def test_optim_state_dict_nested(
    self,
    state_dict_type: StateDictType,
    use_multiple_param_groups: bool,
    rank0_only: bool,
    use_diff_optim_inputs: bool,
) -> None:
    """
    Tests :meth:`full_optim_state_dict` and `sharded_optim_state_dict` by
    comparing the returned dict for an FSDP-wrapped model with that of an
    equivalent non-wrapped model.

    The comparison excludes the parameter keys, since the FSDP optimizer
    state dict keys by names while the normal one keys by IDs. The test can
    therefore pass even if parameter keys are incorrectly mapped to values;
    that mapping is covered by the tests exercising the save/load workflow.
    """
    if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
        return  # not supported
    NUM_ITERS = 3
    init_kwargs = {
        "use_multiple_param_groups": use_multiple_param_groups,
        "use_diff_optim_inputs": use_diff_optim_inputs,
    }

    # Wrapped model: train, then fetch the optimizer state dict under test
    model1, optim1, optim_input = self._init_nested_model(wrap=True, **init_kwargs)
    losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
    if state_dict_type == StateDictType.FULL_STATE_DICT:
        fsdp_osd = FSDP.full_optim_state_dict(
            model1,
            optim1,
            optim_input,
            rank0_only=rank0_only,
        )
    else:
        fsdp_osd = FSDP.sharded_optim_state_dict(model1, optim1, optim_input)

    # Non-target ranks get an empty state dict
    if rank0_only and self.rank != 0:
        self.assertEqual(len(fsdp_osd), 0)
        return

    # Non-wrapped reference model trained for the same number of iterations
    model2, optim2, _ = self._init_nested_model(wrap=False, **init_kwargs)
    losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
    ref_osd = optim2.state_dict()

    # Check the losses to eliminate model drift as a source of error
    for i, (l1, l2) in enumerate(zip(losses1, losses2)):
        assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"

    # Do not check the parameter keys since the full/sharded optimizer state
    # dict uses parameter names, while the non-wrapped equivalent uses
    # parameter IDs
    self._check_same_param_groups(fsdp_osd, ref_osd, check_same_param_keys=False)
    self._check_same_state(fsdp_osd, ref_osd, check_same_param_keys=False)
def wrap(
    sharding_strategy: ShardingStrategy,
    device: torch.device,
    init_policy=always_wrap_policy,
):
    """Builds a ``Model`` wrapped in FSDP with a ``ParamExecOrderWrapPolicy``
    and moves it to ``device``.

    Args:
        sharding_strategy: Forwarded to the FSDP constructor.
        device: Device the wrapped model is moved to.
        init_policy: Passed as ``init_policy`` to
            ``ParamExecOrderWrapPolicy``.

    Returns:
        The result of calling ``.to(device)`` on the FSDP-wrapped model.
    """
    exec_order_policy = ParamExecOrderWrapPolicy(init_policy=init_policy)
    return FSDP(
        Model(),
        auto_wrap_policy=exec_order_policy,
        sharding_strategy=sharding_strategy,
    ).to(device)
def test_rekey_optim_state_dict_to_names(
    self,
    use_multiple_param_groups: bool,
):
    """Tests :meth:`rekey_optim_state_dict` with the new keys being
    parameter names: a non-wrapped model (i.e. without FSDP modules) rekeys
    its optimizer state dict to parameter names, the result is sharded with
    :meth:`shard_full_optim_state_dict`, and it must then match the per-rank
    optimizer state dict of a wrapped model (i.e. with FSDP modules)."""
    NUM_ITERS = 3

    def _init_and_step(wrap):
        model, optim, optim_input = self._init_nested_model(
            wrap=wrap,
            use_multiple_param_groups=use_multiple_param_groups,
        )
        self._step_model(model, optim, num_iters=NUM_ITERS)
        return model, optim, optim_input

    # Run a wrapped model for a few iterations
    model1, optim1, optim_input1 = _init_and_step(wrap=True)
    # Run a non-wrapped model for a few iterations
    model2, optim2, optim_input2 = _init_and_step(wrap=False)

    # Re-key the non-wrapped model's optimizer state dict using parameter
    # names (still according to itself)
    rekeyed_osd = FSDP.rekey_optim_state_dict(
        optim2.state_dict(),
        OptimStateKeyType.PARAM_NAME,
        model2,
        optim_input2,
    )
    # Shard the non-wrapped model's re-keyed optimizer state dict, which
    # maps back to (flattened) parameter IDs
    sharded_osd = FSDP.shard_full_optim_state_dict(
        rekeyed_osd,
        model1,
        optim_input1,
    )

    # Check that this sharded optimizer state dict matches the wrapped
    # model's per-rank optimizer state dict
    osd1 = optim1.state_dict()
    self._check_same_param_groups(sharded_osd, osd1, check_same_param_keys=True)
    self._check_same_state(sharded_osd, osd1, check_same_param_keys=True)

    # As a sanity check, check that we can load and run a few iterations
    optim1.load_state_dict(sharded_osd)
    self._step_model(model1, optim1, num_iters=NUM_ITERS)
def wrap(
    sharding_strategy: ShardingStrategy,
    device: torch.device,
    wrap_policy: Callable,
) -> torch.nn.Module:
    """Builds a ``Model`` wrapped in FSDP and moves it to ``device``.

    Args:
        sharding_strategy: Forwarded to the FSDP constructor.
        device: Device the wrapped model is moved to.
        wrap_policy: Passed as ``auto_wrap_policy`` to the FSDP constructor.

    Returns:
        The result of calling ``.to(device)`` on the FSDP-wrapped model.
    """
    return FSDP(
        Model(),
        auto_wrap_policy=wrap_policy,
        sharding_strategy=sharding_strategy,
    ).to(device)
def test_params_are_unflatenned(self):
    """Checks that, inside ``_summon_full_params()``, the wrapped linear
    layer's flattened ``weight`` equals the flat parameter."""
    # One input feature per rank, so each local shard holds one element
    linear = nn.Linear(self.world_size, 1, bias=False)
    model = FSDP(linear).cuda(self.rank)
    flattened_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    self.assertEqual(1, flattened_param.numel())
    with model._summon_full_params():
        flat_weight = model.weight.flatten().detach()
        self.assertTrue(torch.equal(flat_weight, flattened_param.detach()))
def __init__(self, wrap_fsdp, register_buffers=False):
    """Creates the ``inner`` and ``outer`` linear layers.

    Args:
        wrap_fsdp: If truthy, ``inner`` is wrapped in FSDP.
        register_buffers: If truthy, a random buffer named ``"buffer"`` of
            shape ``BUFFER_SHAPE`` is registered on each linear layer.
    """
    super().__init__()
    # Construction order is kept (inner layer, inner buffer, outer layer,
    # outer buffer) so the random draws happen in the same sequence
    inner = Linear(*INNER_SHAPE)
    if register_buffers:
        inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
    self.inner = FSDP(inner) if wrap_fsdp else inner
    outer = Linear(*OUTER_SHAPE)
    self.outer = outer
    if register_buffers:
        outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))