def test_multiple_wrapping(self):
    """
    This test simulates wrapping the module after training to run inference.
    This is required in cases where later in a session, the model is wrapped
    again in FSDP but contains nested FSDP wrappers within the module.
    """
    inner_model = InnerModel()
    model = FSDP(inner_model).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for i in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    # second time to rewrap the inner model
    rewrapped_model = FSDP(inner_model).cuda()
    rewrapped_output = rewrapped_model(input)

    self.assertEqual(output, rewrapped_output)
def wrap_alt(model, group=None) -> torch.nn.Module:
    model.block0.bias_module0 = FSDP(
        model.block0.bias_module0,
        process_group=group,
    )
    model.block0 = FSDP(model.block0, process_group=group)
    return model
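# Hedged sketch (assumed, not part of the original file): a minimal module
# layout that `wrap_alt` above expects, namely a model exposing `block0` with
# a `bias_module0` child, so the child is FSDP-wrapped before its parent.
# `DummyBlockModel` and its submodule names are hypothetical.
class DummyBlockModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.block0 = nn.Sequential(nn.Linear(4, 4))
        self.block0.bias_module0 = nn.Linear(4, 4)


# model = wrap_alt(DummyBlockModel(), group=None)
# -> model.block0 is an FSDP instance containing a nested FSDP(bias_module0)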
def test_diff_ignored_modules_across_ranks(
    self, pass_ignored_modules_to_root: bool
):
    """
    Tests ignoring different modules across ranks.

    Args:
        pass_ignored_modules_to_root (bool): If ``False``, does not pass
            any ignored modules (including those already ignored in child
            FSDP instances) to the root FSDP instance; if ``True``, passes
            all ignored modules (representing a superset of the children's
            ignored modules) to the root FSDP instance.
    """
    # To exercise different `FlatParameter` enumerations across ranks,
    # we wrap `layer3` with FSDP, where `layer3` is registered as a module
    # after `layer1`, which has the variable number of ignored modules
    model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
    layer1_ignored_modules = [
        m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
    ]
    model.layer1 = FSDP(model.layer1, ignored_modules=layer1_ignored_modules)
    model.layer3 = FSDP(model.layer3)
    model_ignored_modules = [
        m for m in model.modules() if isinstance(m, IgnoredModule)
    ] if pass_ignored_modules_to_root else []
    wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
def test_summon_full_params_respects_reshard_after_forward(
    self, mixed_precision
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 3, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # trigger lazy init
    model(torch.zeros(5).cuda(self.rank))

    # the root FSDP module keeps all params around
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # similarly summon_full_params should have the same behavior
    with model.summon_full_params(model):
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
def test_ignored_modules_nested(self):
    """Tests that passing a module with nested FSDP modules does not error
    and still ignores non-FSDP modules' parameters."""
    # Initialize an FSDP-wrapped nested model that first wraps the nested
    # sequential's middle linear layer (`layer1[1]`) and then wraps the
    # overall model while ignoring the nested sequential (`layer1`)
    model = Model().cuda()
    model.layer1[1] = FSDP(model.layer1[1])
    wrapped_model = FSDP(model, ignored_modules=[model.layer1])
    # Check that the wrapped model's flattened parameter does not include
    # the ignored nested sequential's parameters
    nonwrapped_model = Model()
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.layer1.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    device = torch.device("cuda")
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    for _ in range(3):
        inp = wrapped_model.get_input(device)
        output = wrapped_model(*inp)
        loss = wrapped_model.get_loss(inp, output).to(device)
        wrapped_model.run_backward(loss)
        optim.step()
def test_no_params(self):
    """
    Test that device_id and CPU init work if the module has no params
    (they are effectively no-ops, but ensure FSDP does not assume the
    module has parameters during init).
    """
    # Test CPU
    no_params = nn.ReLU()
    module = FSDP(no_params)
    # Test CUDA
    no_params = nn.ReLU().cuda()
    module = FSDP(no_params)
    # Test CPU + device_id
    no_params = nn.ReLU()
    module = FSDP(no_params, device_id=torch.cuda.current_device())
    # For modules with no params, a wrong device_id will raise an error about
    # the inconsistency between compute_device and device_id, since
    # compute_device is computed as torch.cuda.current_device when there are
    # no params.
    no_params = nn.ReLU().cuda()
    context = (
        self.assertRaisesRegex(
            AssertionError, f"Inconsistent.*cuda:{self.rank} vs cuda:0"
        )
    ) if self.rank != 0 else suppress()
    with context:
        module = FSDP(no_params, device_id=0)
def test_fsdp_same_model_across_ranks(self):
    """
    FSDP broadcasts the model from rank 0 to ensure it starts off with the
    same values.
    """
    class MyModel(nn.Module):
        def __init__(self, rank):
            super().__init__()
            # Seed via rank to make model different across ranks
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)
            self.lin = nn.Linear(10, 10, bias=False)
            self.register_buffer("buffer", torch.ones(1) * rank)

    m = MyModel(self.rank).cuda()
    _assert_module_states(
        m, process_group=self.process_group, assert_fn=self.assertNotEqual
    )
    # Passing sync_module_states into FSDP makes the model the same during init.
    fsdp = FSDP(m, sync_module_states=True)
    with fsdp.summon_full_params(fsdp):
        _assert_module_states(
            fsdp, process_group=self.process_group, assert_fn=self.assertEqual
        )

    # sync_module_states also works with a CPU module with device_id passed in
    m = MyModel(self.rank)
    _assert_module_states(
        m, process_group=self.process_group, assert_fn=self.assertNotEqual
    )
    # Passing sync_module_states into FSDP makes the model the same during init.
    fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
    with fsdp.summon_full_params(fsdp):
        _assert_module_states(
            fsdp, process_group=self.process_group, assert_fn=self.assertEqual
        )
def __init__(self, has_wrapping, sharding_strategy, mixed_precision=None):
    # to ensure determinism
    torch.manual_seed(0)
    torch.cuda.manual_seed(0)
    super().__init__()
    if has_wrapping:
        self.net = FSDP(
            nn.Sequential(
                nn.Linear(8, 16),
                nn.ReLU(),
                FSDP(
                    nn.Linear(16, 8),
                    device_id=torch.cuda.current_device(),
                    sharding_strategy=sharding_strategy,
                    mixed_precision=mixed_precision,
                ),
            ),
            device_id=torch.cuda.current_device(),
            sharding_strategy=sharding_strategy,
            mixed_precision=mixed_precision,
        )
    else:
        self.net = nn.Sequential(
            nn.Linear(8, 16),
            nn.ReLU(),
            nn.Linear(16, 8),
        )
    self.out = nn.Linear(8, 4)
def _run_test_summon_full_param_writeback(
    cls, writeback, cpu_offload, modify_outer
):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(cls.rank)

    # set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size = 1, FSDP does not shard and the parameter is not
        # set to a local shard, so the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
def test_summon_full_param_writeback(
    self, writeback, cpu_offload, modify_outer
):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(self.rank)

    # set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = self.rank + 2

    with model._summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback:
        self.assertEqual(p.cpu()[0], 0)
    else:
        self.assertEqual(p.cpu()[0], self.rank + 2)
def test_summon_full_param_recursive(self, recurse, summon_outer):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(self.rank)

    global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
    global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

    shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
    shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    self.assertEqual(shard_outer_numel, outer_param.numel())
    self.assertEqual(shard_inner_numel, inner_param.numel())

    model_to_summon = model if summon_outer else model[0]

    # outer is summoned if _summon_full_params is called on the outer FSDP module
    expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel
    # inner is summoned if _summon_full_params is called with recursion or on
    # the inner FSDP module
    expected_inner_numel = (
        global_inner_numel if recurse or not summon_outer else shard_inner_numel
    )

    with model_to_summon._summon_full_params(recurse=recurse):
        self.assertEqual(expected_outer_numel, outer_param.numel())
        self.assertEqual(expected_inner_numel, inner_param.numel())
def test_reshard_outside_forward_backward_iteration(
    self, rank0_only, offload_to_cpu, mixed_precision
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 1, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # First, validate our assumption about resharding
    output = model(torch.zeros(5).cuda(self.rank))
    # the root FSDP module keeps all params around
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
    output.backward()
    # we reshard everything after backward() finishes
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Now repeat it with a summon_full_params() call in between
    output = model(torch.zeros(5).cuda(self.rank))
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
    with model.summon_full_params(
        model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    with model.summon_full_params(
        model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        pass
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
def create_model(with_fsdp, with_checkpoint, model_hidden_dim):
    torch.manual_seed(0)
    model = Model(model_hidden_dim, with_fsdp, with_checkpoint)
    if with_fsdp:
        model.stem = FSDP(model.stem)
        model.blocks = FSDP(model.blocks)
        model.head = FSDP(model.head)
    return model
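# Hedged usage sketch (assumed, not from the original file): `create_model`
# wraps the `stem`, `blocks`, and `head` submodules individually; callers
# would typically still wrap the returned model in a top-level FSDP instance:
#
#   model = create_model(with_fsdp=True, with_checkpoint=False, model_hidden_dim=128)
#   model = FSDP(model).cuda()
#
# The `Model(model_hidden_dim, with_fsdp, with_checkpoint)` fixture and the
# hidden dimension value used here are assumptions about the surrounding test.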
def _create_model(compute_cycles, has_params: bool):
    model = FSDP(
        nn.Sequential(
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
            FSDP(Layer(compute_cycles, has_params)),
        )
    ).cuda()
    return model
def _get_simple_nested_model(self, *fsdp_args, **fsdp_kwargs):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(10, 10, bias=False), *fsdp_args, **fsdp_kwargs),
            nn.Linear(10, 10, bias=False),
        ),
        *fsdp_args,
        **fsdp_kwargs,
    )
    return model
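# Hedged usage sketch (assumed, not part of the original test): the helper
# forwards arbitrary FSDP args/kwargs to both the inner and the outer wrapper,
# so a test can exercise a feature on a nested model with one call, e.g.:
#
#   model = self._get_simple_nested_model(
#       cpu_offload=CPUOffload(offload_params=True),
#   )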
def _get_simple_nested_model(self, param_dtype, *fsdp_args, **fsdp_kwargs):
    model = FSDP(
        nn.Sequential(
            FSDP(
                LinearMixedPrecision(param_dtype).cuda(),
                *fsdp_args,
                **fsdp_kwargs,
            ),
            LinearMixedPrecision(param_dtype).cuda(),
        ),
        *fsdp_args,
        **fsdp_kwargs,
    )
    return model
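# Hedged usage sketch (assumed): this mixed-precision variant would be driven
# the same way, passing a `MixedPrecision` config through to both FSDP levels:
#
#   mp_config = MixedPrecision(param_dtype=torch.float16)
#   model = self._get_simple_nested_model(
#       param_dtype=torch.float16, mixed_precision=mp_config
#   )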
def test_ignored_modules_invalid(self):
    """Tests that passing an FSDP module as an ignored module errors."""
    model = Model()
    model.layer1 = FSDP(model.layer1)
    # Passing an FSDP module as an ignored module should error
    with self.assertRaises(
        ValueError,
        msg="`ignored_modules` should not include FSDP modules",
    ):
        FSDP(model, ignored_modules=[model.layer1])
def test_fsdp_device_id(self, use_index):
    """
    If a CPU module is passed into FSDP with the device_id argument, it is
    moved to the GPU with that device_id.
    """
    dev_id = (
        torch.cuda.current_device() if use_index
        else torch.device("cuda", torch.cuda.current_device())
    )

    def _check_device_matches(fsdp, dev_id):
        devices = {p.device for p in fsdp.parameters()}
        self.assertEqual(1, len(devices))
        found_dev = devices.pop()
        if use_index and not isinstance(dev_id, torch.device):
            dev_id = torch.device("cuda", dev_id)
        self.assertEqual(found_dev, dev_id)

    mod = NestedWrappedModule(
        group=self.process_group,
        wrap_fsdp=True,
        wrap_everything=True,
        fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
        device_id=dev_id,
    )
    fsdp = FSDP(mod, device_id=dev_id)
    # Check FSDP parameters are moved.
    _check_device_matches(fsdp, dev_id)

    # device_id matching module device before FSDP construction
    # should not throw errors.
    mod = NestedWrappedModule(
        group=self.process_group,
        wrap_fsdp=True,
        wrap_everything=True,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        device_id=dev_id,
    )
    fsdp = FSDP(mod, device_id=dev_id)
    _check_device_matches(fsdp, dev_id)

    # Passing in torch.device("cuda") should work.
    regex = "does not have explicit index"
    context = self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex=regex
    )
    with context:
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            device_id=torch.device("cuda"),
        )
        fsdp = FSDP(mod, device_id=torch.device("cuda"))
        _check_device_matches(
            fsdp, torch.device("cuda", torch.cuda.current_device())
        )
def __init__(
    self,
    group: dist.ProcessGroup,
    wrap_fsdp: bool,
    cuda_init_mode: CUDAInitMode,
    delay_before_free_ms: int,
    deterministic: bool,
    **fsdp_kwargs,
):
    super().__init__(
        group=group,
        wrap_fsdp=wrap_fsdp,
        cuda_init_mode=cuda_init_mode,
        deterministic=deterministic,
    )
    self.group = group
    self.delay_before_free_ms = delay_before_free_ms
    self.wrap_fsdp = wrap_fsdp
    self.move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
    if deterministic:
        # Give each rank different expert parameters
        torch.manual_seed(42 + self.rank)
    d_expert = 23
    d_shared = 12
    d_input = 8
    expert = _maybe_cuda(nn.Linear(d_expert, d_shared), self.move_to_cuda)
    self.num_expert_params = sum([p.numel() for p in expert.parameters()])
    for p in expert.parameters():
        p.expert = True  # type: ignore[attr-defined]
    if deterministic:
        # Keep all other parameters the same across ranks
        torch.manual_seed(0)
    shared = _maybe_cuda(nn.Linear(d_shared, d_expert), self.move_to_cuda)
    if wrap_fsdp:
        # we create a process group of size 1 for the expert params
        expert_group = torch.distributed.new_group([group.rank()])  # world size 1 means no shard
        expert = FSDP(expert, expert_group, **fsdp_kwargs)  # type: ignore[assignment]
        shared = FSDP(shared, group, **fsdp_kwargs)  # type: ignore[assignment]
    self.module = nn.Sequential(
        _maybe_cuda(nn.Linear(d_input, d_shared), self.move_to_cuda),
        shared,
        expert,
        _maybe_cuda(nn.Linear(d_shared, d_input), self.move_to_cuda),
    )
def test_fsdp_cpu_init_stays_on_cpu(self):
    """
    Ensure that CPU model input stays on CPU after FSDP init and we log a
    warning.
    """
    torch.cuda.set_device(self.rank)
    regex = "Module is put on CPU"
    context = self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex=regex
    )
    with context:
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
        )
        fsdp = FSDP(mod)
    devices = {p.device for p in fsdp.parameters()}
    self.assertEqual(1, len(devices))
    self.assertEqual(torch.device("cpu"), devices.pop())
    fsdp = fsdp.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to appropriate
    # CUDA device.
    inp = mod.get_input(device=torch.device("cpu"))
    fsdp(inp[0]).sum().backward()
def test_wrong_state_dict_config(self):
    model = FSDP(Model(wrap_fsdp=True).cuda())
    with self.assertRaisesRegex(RuntimeError, "Expected state_dict_config of type"):
        with model.state_dict_type(
            model, StateDictType.FULL_STATE_DICT, LocalStateDictConfig()
        ):
            pass
def _init_model(
    self,
    nested_model: bool,
    sharding_strategy: ShardingStrategy,
    device: torch.device,
):
    fsdp_kwargs = {"sharding_strategy": sharding_strategy}
    if nested_model:
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_AFTER,
            fsdp_kwargs,
        )
        fsdp_model: FSDP = FSDP(
            model,
            self.process_group,
            **fsdp_kwargs,
        ).to(device)
    else:
        fsdp_model: FSDP = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
    return fsdp_model
def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    """Tests that ``named_parameters()`` and ``named_buffers()`` for a
    top-level FSDP-wrapped model match their behavior for the equivalent
    non-wrapped model."""
    model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    model.register_buffer("buffer", torch.ones(1))
    # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
    # if called on a non-FSDP root module
    fsdp_model = FSDP(
        NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        ),
        self.process_group,
    )
    fsdp_model.register_buffer("buffer", torch.ones(1))
    with FSDP.summon_full_params(fsdp_model):
        for call in ["named_parameters", "named_buffers"]:
            for (n1, p1), (n2, p2) in itertools.zip_longest(
                getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                getattr(model, call)(prefix=prefix, recurse=recurse),
            ):
                self.assertEqual(n1, n2)
                self.assertEqual(p1, p2)
def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    dev = (
        torch.device("cpu") if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )

    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        # The sleep below causes failures if summon_full_params lacks proper
        # stream synchronization.
        torch.cuda._sleep(1000000)
        # deepcopy() of FSDP params has issues, so clone the params instead
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)
def test_params_count_and_value(
    self, rank0_only, offload_to_cpu, mixed_precision
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            mixed_precision=mixed_precision,
        ),
        mixed_precision=mixed_precision,
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )

    dev = (
        torch.device("cpu") if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )

    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else list(p.clone() for p in fsdp_model.parameters())
    )
    with fsdp_model.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_summon_full_param_shard_value(self, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)

    model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))

    # we're assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))

    my_shard = torch.clone(next(model.parameters()))

    with model.summon_full_params(model):
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        parameters = list(model.parameters())
        all_shards = FlatParamHandle.flatten_params(parameters, requires_grad=False)
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

        # shards are padded but the full_param tensor is not
        self.assertTrue(
            torch.equal(my_shard[0:my_slice.numel()].cpu(), my_slice.cpu())
        )
def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module
    model = Model(wrap_fsdp=True).cuda()
    ignored_modules = [model.outer]
    ignored_param_to_param_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters are not cloned
    for param, param_name in ignored_param_to_param_name.items():
        self.assertTrue(param_name in sd)
        self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
def init(
    group: dist.ProcessGroup,
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    fsdp_kwargs: Optional[Dict[str, Any]] = None,
    deterministic: bool = False,
):
    """
    Initializes a :class:`NestedWrappedModule` instance, but unlike
    :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
    wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
    policy.
    """
    super_ = super(AlwaysWrapNestedWrappedModule, AlwaysWrapNestedWrappedModule)
    model = super_.init(
        group=group,
        fsdp_init_mode=FSDPInitMode.NO_FSDP,
        cuda_init_mode=cuda_init_mode,
        fsdp_kwargs=fsdp_kwargs,
        deterministic=deterministic,
    )
    if fsdp_init_mode == FSDPInitMode.NO_FSDP:
        return model
    elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
        fsdp_model = FSDP(model, auto_wrap_policy=always_wrap_policy, **fsdp_kwargs)
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        return fsdp_model
def test_mixed_precision_resnet(self):
    """
    End-to-end test to ensure mixed precision + auto_wrap works for a
    ResNet model.
    """
    resnet_model = torchvision.models.resnet50().cuda()
    resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
        resnet_model, process_group=dist.distributed_c10d._get_default_group()
    )
    n_bn = sum(
        1 if isinstance(x, _BatchNorm) else 0 for x in resnet_model.modules()
    )
    inp = torch.ones(1, 3, 1000, 1000, device='cuda')
    mp_config = MixedPrecision(
        param_dtype=torch.float16,
        reduce_dtype=torch.float16,
        buffer_dtype=torch.float16,
    )
    fsdp = FSDP(
        resnet_model,
        auto_wrap_policy=size_based_auto_wrap_policy,
        mixed_precision=mp_config,
    )
    # BatchNorm units should be wrapped individually. Validate this by
    # ensuring that the number of FSDP units wrapping a BatchNorm module
    # equals the number of BatchNorm units in the original ResNet model.
    fsdp_bn = 0
    for module in fsdp.fsdp_modules(fsdp):
        wrapped_module = module.module.module
        if isinstance(wrapped_module, _BatchNorm):
            fsdp_bn += 1
    self.assertEqual(fsdp_bn, n_bn)
    # Without mixed precision auto-wrapping, this would throw a dtype
    # mismatch issue.
    loss = fsdp(inp).sum()
    loss.backward()
def _init_model(
    self,
    nested_model: bool,
    sharding_strategy: ShardingStrategy,
    device: torch.device,
):
    group = dist.distributed_c10d._get_default_group()
    if nested_model:
        model = NestedWrappedModule(
            group,
            wrap_fsdp=True,
            sharding_strategy=sharding_strategy,
        )
        fsdp_model: FSDP = FSDP(
            model,
            group,
            sharding_strategy=sharding_strategy,
        ).to(device)
    else:
        fsdp_model: FSDP = self._get_wrapped_model(
            group,
            cuda_first=False,
            config={"sharding_strategy": sharding_strategy},
        )
    return fsdp_model