def test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value.
        p[0] = self.rank + 2

    with model._summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback:
        self.assertEqual(p.cpu()[0], 0)
    else:
        self.assertEqual(p.cpu()[0], self.rank + 2)
def test_summon_full_param_shard_value(self):
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)

    model = FSDP(raw_model.cuda(self.rank))
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))

    # We're assuming a single flattened parameter.
    self.assertEqual(1, len(list(model.parameters())))

    my_shard = torch.clone(next(model.parameters()))

    with model._summon_full_params():
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        parameters = list(model.parameters())
        all_shards = FlatParameter(parameters, requires_grad=False)
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

        # Shards are padded, but the full_param tensor is not, so compare
        # only the unpadded prefix of this rank's shard.
        self.assertTrue(
            torch.equal(my_shard[0 : my_slice.numel()].cpu(), my_slice.cpu())
        )
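# A minimal sketch of the shard-size arithmetic that a helper like
# get_expected_sharded_size above is assumed to perform: FSDP flattens all
# parameters into a single FlatParameter and pads it so it divides evenly
# across ranks. The function name and body here are illustrative assumptions,
# not the library's API.
def _expected_sharded_size_sketch(total_numel: int, world_size: int) -> int:
    # Round the element count up to a multiple of world_size, then each rank
    # holds exactly one chunk of the padded flat parameter.
    padded_numel = (total_numel + world_size - 1) // world_size * world_size
    return padded_numel // world_size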
def test_summon_full_params_respects_reshard_after_forward(self):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # Trigger lazy init.
    model(torch.zeros(5).cuda(self.rank))

    # The root FSDP module keeps all params around.
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Similarly, _summon_full_params should have the same behavior.
    with model._summon_full_params():
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
def _run_test_summon_full_param_writeback(self, writeback, cpu_offload, modify_outer):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
        )
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value.
        p[0] = self.rank + 2

    with model._summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or self.world_size == 1:
        # When world_size == 1, FSDP does not shard and the parameter is not
        # set to a local shard, so the write is always reflected.
        self.assertEqual(p.cpu()[0], 0)
    else:
        self.assertEqual(p.cpu()[0], self.rank + 2)
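# A hedged sketch of how the helper above might be exercised across the
# writeback/modify_outer combinations. The subtest-style loop and the
# cpu_offload=None placeholder are illustrative assumptions; a real suite
# might use a parametrize decorator instead.
def test_summon_full_param_writeback_all_combos(self):
    for writeback, modify_outer in itertools.product([True, False], repeat=2):
        with self.subTest(writeback=writeback, modify_outer=modify_outer):
            self._run_test_summon_full_param_writeback(
                writeback, cpu_offload=None, modify_outer=modify_outer
            )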
def test_params_are_unflatenned(self):
    model = FSDP(nn.Linear(self.world_size, 1, bias=False)).cuda(self.rank)

    flattened_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    # nn.Linear(world_size, 1) has world_size elements, so each rank's
    # shard holds exactly one element.
    self.assertEqual(1, flattened_param.numel())

    with model._summon_full_params():
        a = model.weight.flatten().detach()
        b = flattened_param.detach()
        self.assertTrue(torch.equal(a, b))
def test_reshard_outside_forward_backward_iteration(self):
    model = FSDP(
        nn.Sequential(
            FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 1, bias=False)
        )
    ).cuda(self.rank)

    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    outer_full_param_size = outer_param.numel() * self.world_size

    # First, let's validate our assumption about resharding.
    output = model(torch.zeros(5).cuda(self.rank))
    # The root FSDP module keeps all params around after forward.
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    # We reshard everything after backward() finishes.
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    # Now let's repeat it with a summon done in between.
    output = model(torch.zeros(5).cuda(self.rank))
    with model._summon_full_params():
        pass
    self.assertEqual(
        outer_full_param_size, outer_param._full_param_padded.storage().size()
    )
    self.assertEqual(0, inner_param._full_param_padded.storage().size())

    output.backward()
    with model._summon_full_params():
        pass
    self.assertEqual(0, outer_param._full_param_padded.storage().size())
    self.assertEqual(0, inner_param._full_param_padded.storage().size())
def test_params_are_unflattenned(self):
    layer_shape = (10, 12)
    model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
    fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

    flattened_param = fsdp_model.get_parameter(
        "_fsdp_wrapped_module.flat_param"
    )
    # This test assumes a world size of 2, so each rank holds half of the
    # flattened weight.
    self.assertEqual(
        layer_shape[0] * layer_shape[1] / 2, flattened_param.numel()
    )

    with fsdp_model._summon_full_params():
        self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
def test_params_count_and_value(self):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )

    with fsdp_model._summon_full_params():
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), model.module.parameters()
        ):
            self.assertEqual(p1, p2)
def test_summon_single_param(self):
    model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

    p = model.get_parameter("_fsdp_wrapped_module.flat_param")
    self.assertEqual(1, p.numel())

    with torch.no_grad():
        # This sets the local shard value.
        p[0] = self.rank + 2

    with model._summon_full_params(writeback=True):
        self.assertEqual(1, p.numel())
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    # Most ranks hold no real data and wrote into padding, so only rank 0
    # will observe the write above.
    if self.rank == 0:
        self.assertEqual(0, p[0])
    else:
        self.assertEqual(self.rank + 2, p[0])
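# The padding behavior asserted above, as a standalone sketch (not part of
# the test class): a 1-element flat parameter sharded across N ranks is
# padded to N elements, so only rank 0's shard maps to real data; every
# other rank's shard is pure padding. The helper is illustrative only.
def _shard_is_all_padding_sketch(rank: int, world_size: int, total_numel: int = 1) -> bool:
    shard_size = (total_numel + world_size - 1) // world_size  # padded chunk
    start = rank * shard_size
    # True when this rank's chunk begins past the real data, i.e. is padding.
    return start >= total_numel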
def _zero_model(fsdp_model: FullyShardedDataParallel):
    # Zero all parameters in place under summon, so the full (unsharded)
    # values are updated on every rank.
    with fsdp_model._summon_full_params():
        with torch.no_grad():
            for param in fsdp_model.parameters():
                param.zero_()
def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
    # Clone under summon so each rank returns the full, unsharded values.
    with fsdp_model._summon_full_params():
        params = [p.clone().detach_() for p in fsdp_model.parameters()]
    return params
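# Illustrative usage of the two helpers above; the function name and the
# bare asserts are assumptions for the sketch, not part of the suite. It
# snapshots the full parameters, zeroes the model, and confirms every full
# parameter now compares equal to a zeros tensor of the same shape.
def _check_zeroing_roundtrip(fsdp_model: FullyShardedDataParallel):
    before = _get_full_detached_param(fsdp_model)
    _zero_model(fsdp_model)
    after = _get_full_detached_param(fsdp_model)
    assert all(
        torch.equal(a, torch.zeros_like(b)) for a, b in zip(after, before)
    )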