def test_fsdp_same_model_across_ranks(self):
    """
    FSDP broadcasts model from rank 0 to ensure it starts off with the same values.
    """
    class MyModel(nn.Module):
        def __init__(self, rank):
            super().__init__()
            # Seed via rank to make model different across ranks
            torch.manual_seed(rank)
            torch.cuda.manual_seed(rank)
            self.lin = nn.Linear(10, 10, bias=False)
            self.register_buffer("buffer", torch.ones(1) * rank)

    m = MyModel(self.rank).cuda()
    _assert_module_states(
        m, process_group=self.process_group, assert_fn=self.assertNotEqual
    )
    # Passing sync_module_states into FSDP makes model the same during init.
    fsdp = FSDP(m, sync_module_states=True)
    with fsdp.summon_full_params(fsdp):
        _assert_module_states(
            fsdp, process_group=self.process_group, assert_fn=self.assertEqual
        )

    # sync_module_states also works with a CPU module with device_id passed in
    m = MyModel(self.rank)
    _assert_module_states(
        m, process_group=self.process_group, assert_fn=self.assertNotEqual
    )
    # Passing sync_module_states into FSDP makes model the same during init.
    fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
    with fsdp.summon_full_params(fsdp):
        _assert_module_states(
            fsdp, process_group=self.process_group, assert_fn=self.assertEqual
        )

def test_reshard_outside_forward_backward_iteration(
    self, rank0_only, offload_to_cpu, mixed_precision
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 1, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # First, validate our assumption about resharding
    output = model(torch.zeros(5).cuda(self.rank))
    # The root FSDP module keeps all params around
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    # We reshard everything after backward() finishes
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Now repeat with summon_full_params() done in between
    output = model(torch.zeros(5).cuda(self.rank))
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
    with model.summon_full_params(
        model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    with model.summon_full_params(
        model,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        pass
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    """Tests that ``named_parameters()`` and ``named_buffers()`` for a
    top-level FSDP-wrapped model match their behavior for the equivalent
    non-wrapped model."""
    model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    model.register_buffer("buffer", torch.ones(1))
    # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
    # if called on a non-FSDP root module
    fsdp_model = FSDP(
        NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        ),
        self.process_group,
    )
    fsdp_model.register_buffer("buffer", torch.ones(1))
    with FSDP.summon_full_params(fsdp_model):
        for call in ["named_parameters", "named_buffers"]:
            for (n1, p1), (n2, p2) in itertools.zip_longest(
                getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                getattr(model, call)(prefix=prefix, recurse=recurse),
            ):
                self.assertEqual(n1, n2)
                self.assertEqual(p1, p2)

def test_params_are_unflattenned(self, rank0_only, offload_to_cpu):
    layer_shape = (10, 12)
    model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
    fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

    def _get_flat_param():
        return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

    flattened_param = _get_flat_param()
    self.assertEqual(layer_shape[0] * layer_shape[1] / 2, flattened_param.numel())

    with fsdp_model.summon_full_params(
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        if self.rank == 0 or not rank0_only:
            self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
            expected_device = (
                torch.device("cpu")
                if offload_to_cpu
                else torch.device("cuda", torch.cuda.current_device())
            )
            self.assertTrue(expected_device == fsdp_model.weight.device)
        else:
            # Nonzero rank with rank0_only maintains original params.
            flat_within_ctx = _get_flat_param()
            self.assertEqual(flat_within_ctx, flattened_param)
            self.assertEqual(
                flat_within_ctx.device, torch.device(torch.cuda.current_device())
            )

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )

def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )

    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        # The sleep below causes failures without the stream synchronization
        # in the summon_full_params fix.
        torch.cuda._sleep(1000000)
        # deepcopy() of FSDP-managed params has issues, so clone instead
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)

def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload, modify_outer):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(cls.rank)

    # Set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard and the parameter is not
        # set to a local shard, so the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)

def test_summon_full_params_respects_reshard_after_forward(self, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False), mixed_precision=mixed_precision),
            nn.Linear(5, 3, bias=False),
        ),
        mixed_precision=mixed_precision,
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # Trigger lazy init
    model(torch.zeros(5).cuda(self.rank))

    # The root FSDP module keeps all params around
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Similarly, summon_full_params should have the same behavior
    with model.summon_full_params(model):
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

def test_ignored_modules_nested(self):
    """Tests that passing a module with nested FSDP modules does not
    error and still ignores non-FSDP modules' parameters."""
    # Initialize an FSDP-wrapped nested model that first wraps the nested
    # sequential's middle linear layer (`layer1[1]`) and then wraps the
    # overall model while ignoring the nested sequential (`layer1`)
    model = Model().cuda()
    model.layer1[1] = FSDP(model.layer1[1])
    wrapped_model = FSDP(model, ignored_modules=[model.layer1])
    # Check that the wrapped model's flattened parameter does not include
    # the ignored nested sequential's parameters
    nonwrapped_model = Model()
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.layer1.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    device = torch.device("cuda")
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    for _ in range(3):
        inp = wrapped_model.get_input(device)
        output = wrapped_model(*inp)
        loss = wrapped_model.get_loss(inp, output).to(device)
        wrapped_model.run_backward(loss)
        optim.step()

def test_params_count_and_value(
    self,
    rank0_only: bool,
    offload_to_cpu: bool,
    mixed_precision: bool,
):
    mixed_precision = MixedPrecision() if mixed_precision else None
    model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    fsdp_model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else list(p.clone() for p in fsdp_model.parameters())
    )
    with FSDP.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )

def test_ignored_modules_transformer(self):
    """Tests that ignored modules' parameters are not flattened for a
    transformer model with shared parameters."""
    # Initialize an FSDP-wrapped transformer model that has FSDP ignore
    # the `nn.Transformer` module's parameters
    group = dist.distributed_c10d._get_default_group()
    wrapped_model = self._get_wrapped_model(group, ignore_modules=True)
    # Check that the wrapped model's flattened parameter does not include
    # the ignored transformer module's parameters
    nonwrapped_model = self._get_nonwrapped_model(group)
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.transformer.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    device = torch.device("cuda")
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    for _ in range(3):
        inp = wrapped_model.module.get_input(device)
        output = wrapped_model(*inp)
        loss = wrapped_model.module.get_loss(inp, output).to(device)
        wrapped_model.module.run_backward(loss)
        optim.step()

def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module
    model = Model(wrap_fsdp=True).cuda()
    ignored_modules = [model.outer]
    ignored_param_to_param_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters are not cloned
    for param, param_name in ignored_param_to_param_name.items():
        self.assertTrue(param_name in sd)
        self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)

def test_summon_full_param_shard_value(self, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)

    model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))

    # We are assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))

    my_shard = torch.clone(next(model.parameters()))

    with model.summon_full_params(model):
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        parameters = list(model.parameters())
        all_shards = FlatParamHandle.flatten_params(parameters, requires_grad=False)
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

        # Shards are padded but the full_param tensor is not
        a, b = my_shard[0 : my_slice.numel()], my_slice
        self.assertTrue(torch.equal(a.cpu(), b.cpu()))

def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            mixed_precision=mixed_precision,
        ),
        mixed_precision=mixed_precision,
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )

    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )

    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else list(p.clone() for p in fsdp_model.parameters())
    )
    with fsdp_model.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )

def test_ignored_modules_transformer(self):
    """Tests that ignored modules' parameters are not flattened for a
    transformer model with shared parameters."""
    # Initialize an FSDP-wrapped transformer model that has FSDP ignore
    # the `nn.Transformer` module's parameters
    model: nn.Module = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    wrapped_model = FSDP(
        model,
        self.process_group,
        ignored_modules=[model.transformer],
    )
    # Check that the wrapped model's flattened parameter does not include
    # the ignored transformer module's parameters
    nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.transformer.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)

def check_weights(self, fsdp, expected_tensor_fn, check):
    with FSDP.summon_full_params(fsdp, recurse=True):
        linear_modules = [
            module for module in fsdp.modules() if type(module) == nn.Linear
        ]
        for module in linear_modules:
            for param in module.parameters():
                expected = expected_tensor_fn(param)
                check(param, expected, f"Got {param} but expected {expected}")

def test_params_are_unflattenned(self):
    layer_shape = (10, 12)
    model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
    fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

    flattened_param = fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")
    self.assertEqual(layer_shape[0] * layer_shape[1] / 2, flattened_param.numel())

    with fsdp_model.summon_full_params():
        self.assertEqual(fsdp_model.weight.shape, model.weight.shape)

def get_full_params(model: nn.Module, recurse: bool = True):
    """
    Returns the full unsharded parameters of ``model``. Any FSDP-managed
    parameters offloaded to CPU are moved to GPU in the returned list.

    Args:
        recurse (bool): If ``False``, only unshards the parameters immediate
            to ``model``; if ``True``, recurses through the module hierarchy
            rooted at ``model``.
    """
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))

def test_raises_rank0_with_writeback(self):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )

    with self.assertRaisesRegex(ValueError, "is not supported"):
        with fsdp_model.summon_full_params(rank0_only=True, writeback=True):
            pass

def _zero_model(
    model: nn.Module,
    zero_buffers: bool = False,
):
    """Zeros the parameters and optionally buffers of ``model`` in place."""
    with FSDP.summon_full_params(model):
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                with torch.no_grad():
                    buffer.zero_()

def test_raises_rank0_with_writeback(self):
    """Tests that ``summon_full_params()`` with both ``rank0_only=True``
    and ``writeback=True`` raises an error."""
    nested_wrapped_module = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
    )
    with self.assertRaisesRegex(ValueError, "is not supported"):
        with FSDP.summon_full_params(
            nested_wrapped_module, rank0_only=True, writeback=True
        ):
            pass

def test_summon_full_params_equivalence(self):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    with model.summon_full_params(recurse=True):
        # The sleep below causes failures without the stream synchronization
        # in the summon_full_params fix.
        torch.cuda._sleep(1000000)
        fsdp_params = deepcopy(list(model.parameters()))

    self.assertEqual(fsdp_params, list(local_model.parameters()))

def test_params_count_and_value(self):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    with fsdp_model.summon_full_params():
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), model.module.parameters()
        ):
            self.assertEqual(p1, p2)

def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module that includes
    # both parameters and a buffer
    model = Model(wrap_fsdp=True, register_buffers=True).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
        model.outer.buffer: "outer.buffer",
    }
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)
        self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version of
    # the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd1)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that if we save a state dict again, the ignored parameters and
    # buffers still have the same data pointers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict()
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)  # check again just in case
        self.assertTrue(tensor_name in sd2)
        self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
        self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())

def test_state_dict_rank0_offload_save_load_flow(self):
    # Test taking a checkpoint on rank 0 only and reloading it without
    # redundant CPU memory.
    model = TransformerWithSharedParams(
        group=dist.distributed_c10d._get_default_group()
    )
    my_auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={TransformerEncoderLayer, TransformerDecoderLayer},
    )
    model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
    ctx = self._get_state_dict_mgr(model, "state_dict", True)
    with ctx:
        state_dict = deepcopy(_get_state_dict(model))

    # All ranks initialize a non-FSDP model
    grp = dist.distributed_c10d._get_default_group()
    model_new = TransformerWithSharedParams(group=grp)
    for p in model_new.parameters():
        with torch.no_grad():
            p.zero_()

    # Only rank 0 loads the checkpoint
    if self.rank == 0:
        model_new.load_state_dict(state_dict)

    # TransformerWithSharedParams has a buffer of zeros, so we cannot use
    # self.assertNotEqual since the buffers would be equal. Instead, just
    # check that there is some difference in the model across ranks before
    # state_dict is broadcast.
    with self.assertRaisesRegex(AssertionError, "Tensor-likes are not close"):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    # FSDP with sync_module_states=True broadcasts the checkpointed states.
    model_new = FSDP(
        model_new,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=my_auto_wrap_policy,
        sync_module_states=True,
    )
    # After wrapping with FSDP, models are equal across ranks and have loaded
    # the checkpoint
    with FSDP.summon_full_params(model_new):
        _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

    with FullyShardedDataParallel.summon_full_params(model):
        with FullyShardedDataParallel.summon_full_params(model_new):
            params = list(model.parameters())
            params_new = list(model_new.parameters())
            self.assertEqual(params, params_new)

def test_summon_single_param(self):
    model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

    p = model.get_parameter("_fsdp_wrapped_module.flat_param")
    self.assertEqual(1, p.numel())

    with torch.no_grad():
        # This sets the local shard value
        p[0] = self.rank + 2

    with model.summon_full_params(model, writeback=True):
        self.assertEqual(1, p.numel())
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    # Most ranks hold no data and wrote to padding, so only rank zero will
    # observe the above write
    if self.rank == 0:
        self.assertEqual(0, p[0])
    else:
        self.assertEqual(self.rank + 2, p[0])

def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )

    writeback = not rank0_only

    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=writeback,
        offload_to_cpu=offload_to_cpu,
    ):
        if writeback:
            with torch.no_grad():
                for p in model.parameters():
                    p.add_(1)
                for p in params_to_compare:
                    p.add_(1)

        # The sleep below causes failures without the stream synchronization
        # in the summon_full_params fix.
        torch.cuda._sleep(1000000)
        # deepcopy() of FSDP-managed params has issues, so clone instead
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)

    # CPU offload is enabled for the main API, so params should point back to CPU
    for param in model.parameters():
        self.assertEqual(param.device, torch.device("cpu"))

def test_ignored_modules_nested(self):
    """Tests that passing a module with nested FSDP modules does not
    error and still ignores non-FSDP modules' parameters."""
    # Initialize an FSDP-wrapped nested model that first wraps the nested
    # sequential's second linear layer (`layer1[1]`) and then wraps the
    # overall model while ignoring the nested sequential (`layer1`)
    model = Model().cuda()
    model.layer1[1] = FSDP(model.layer1[1])
    wrapped_model = FSDP(model, ignored_modules=[model.layer1])
    # Check that the wrapped model's flattened parameter does not include
    # the ignored nested sequential's parameters
    nonwrapped_model = Model()
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.layer1.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)

def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    fsdp_model.register_buffer("buffer", torch.ones(1))
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    model.register_buffer("buffer", torch.ones(1))
    with fsdp_model.summon_full_params(fsdp_model):
        for call in ["named_parameters", "named_buffers"]:
            for (n1, p1), (n2, p2) in itertools.zip_longest(
                getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                getattr(model, call)(prefix=prefix, recurse=recurse),
            ):
                self.assertEqual(n1, n2)
                self.assertEqual(p1, p2)

def test_ignored_modules_transformer(self):
    """Tests that ignored modules' parameters are not flattened for a
    transformer model with shared parameters."""
    # Initialize an FSDP-wrapped transformer model that has FSDP ignore
    # the `nn.Transformer` module's parameters
    group = dist.distributed_c10d._get_default_group()
    wrapped_model = self._get_wrapped_model(
        group,
        cuda_first=True,
        ignore_modules=True,
    )
    # Check that the wrapped model's flattened parameter does not include
    # the ignored transformer module's parameters
    nonwrapped_model = self._get_nonwrapped_model(group)
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.transformer.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)

def test_summon_from_non_fsdp(self):
    class FSDPContainer(nn.Module):
        def __init__(self, fsdp_1, fsdp_2, fsdp_3):
            super().__init__()
            self.fsdp_1 = fsdp_1
            self.fsdp_2 = fsdp_2
            self.fsdp_3 = fsdp_3

    model_fsdp = FSDPContainer(
        FSDP(DeterministicModel(wrap_fsdp=True)),
        FSDP(DeterministicModel(wrap_fsdp=True)),
        DeterministicModel(wrap_fsdp=False),
    )
    model_no_fsdp = FSDPContainer(
        DeterministicModel(wrap_fsdp=False),
        DeterministicModel(wrap_fsdp=False),
        DeterministicModel(wrap_fsdp=False),
    )

    params_to_compare = list(model_no_fsdp.parameters())
    with FSDP.summon_full_params(model_fsdp):
        fsdp_params = [p.clone() for p in model_fsdp.parameters()]

    self.assertEqual(params_to_compare, fsdp_params)