def test_params_count_and_value(self, rank0_only, offload_to_cpu, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            mixed_precision=mixed_precision,
        ),
        mixed_precision=mixed_precision,
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    dev = (
        torch.device("cpu")
        if offload_to_cpu
        else torch.device("cuda", torch.cuda.current_device())
    )
    params_to_compare = (
        [p.to(dev) for p in model.module.parameters()]
        if not rank0_only or self.rank == 0
        else list(p.clone() for p in fsdp_model.parameters())
    )
    with fsdp_model.summon_full_params(
        fsdp_model, rank0_only=rank0_only, writeback=not rank0_only
    ):
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), params_to_compare
        ):
            self.assertEqual(p1, p2)

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def test_summon_full_param_shard_value(self):
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)
    model = FSDP(raw_model.cuda(self.rank))
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))

    # we're assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))

    my_shard = torch.clone(next(model.parameters()))

    with model._summon_full_params():
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        all_shards = next(model.parameters())
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

        # shards are padded but the full_param tensor is not
        a, b = my_shard[0 : my_slice.numel()], my_slice
        self.assertTrue(torch.equal(a.cpu(), b.cpu()))
def test_summon_full_param_shard_value(self, mixed_precision):
    mixed_precision = MixedPrecision() if mixed_precision else None
    raw_model = nn.Linear(10, 11)
    raw_model_size = self.get_model_param_count(raw_model)
    expected_shard_size = self.get_expected_sharded_size(raw_model_size)
    model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
    self.assertEqual(expected_shard_size, self.get_model_param_count(model))

    # we're assuming a single flattened param
    self.assertEqual(1, len(list(model.parameters())))

    my_shard = torch.clone(next(model.parameters()))

    with model.summon_full_params(model):
        self.assertEqual(raw_model_size, self.get_model_param_count(model))
        parameters = list(model.parameters())
        all_shards = FlatParamHandle.flatten_params(
            parameters, requires_grad=False
        )
        my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

        # shards are padded but the full_param tensor is not
        a, b = my_shard[0 : my_slice.numel()], my_slice
        self.assertTrue(torch.equal(a.cpu(), b.cpu()))
def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )
    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        # The sleep below causes failures without the stream
        # synchronization fix in summon_full_params.
        torch.cuda._sleep(1000000)
        # deepcopy() of FSDP params has issues, so compare clones instead
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)
def test_ignored_modules_nested(self):
    """Tests that passing a module with nested FSDP modules does not
    error and still ignores non-FSDP modules' parameters."""
    # Initialize an FSDP-wrapped nested model that first wraps the nested
    # sequential's middle linear layer (`layer1[1]`) and then wraps the
    # overall model while ignoring the nested sequential (`layer1`)
    model = Model().cuda()
    model.layer1[1] = FSDP(model.layer1[1])
    wrapped_model = FSDP(model, ignored_modules=[model.layer1])
    # Check that the wrapped model's flattened parameter does not include
    # the ignored nested sequential's parameters
    nonwrapped_model = Model()
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.layer1.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    device = torch.device("cuda")
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    for _ in range(3):
        inp = wrapped_model.get_input(device)
        output = wrapped_model(*inp)
        loss = wrapped_model.get_loss(inp, output).to(device)
        wrapped_model.run_backward(loss)
        optim.step()
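# NOTE: The `Model` fixture used by the ignored-modules tests in this section
# is defined elsewhere in the test file and is not shown here. The sketch
# below is a hypothetical stand-in that illustrates only the structure the
# tests rely on: `layer1` is an `nn.Sequential` (so `layer1[1]` is its middle
# linear layer), and the module exposes `get_input`/`get_loss`/`run_backward`
# helpers. Layer sizes and method bodies are illustrative assumptions, not
# the actual fixture.
class _SketchModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.layer0 = nn.Linear(4, 4)
        self.layer1 = nn.Sequential(
            nn.Linear(4, 4),
            nn.Linear(4, 4),  # `layer1[1]`: wrapped separately with FSDP
            nn.Linear(4, 4),
        )
        self.layer2 = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer2(self.layer1(self.layer0(x)))

    def get_input(self, device):
        return (torch.randn(8, 4, device=device),)

    def get_loss(self, inp, output):
        return output.sum()

    def run_backward(self, loss):
        loss.backward()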
def test_fsdp_cpu_init_stays_on_cpu(self):
    """Tests that passing a CPU module to FSDP preserves that the wrapped
    module is on CPU after FSDP initialization, albeit after logging a
    warning, and that FSDP moves CPU input to GPU before the forward."""
    torch.cuda.set_device(self.rank)
    regex = "Module is put on CPU"
    context = self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex=regex
    )
    with context:
        nested_wrapped_module = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_NEVER,
        )
        fsdp_model = FSDP(nested_wrapped_module, self.process_group)
    devices = {p.device for p in fsdp_model.parameters()}
    self.assertEqual(1, len(devices))
    self.assertEqual(torch.device("cpu"), devices.pop())
    fsdp_model = fsdp_model.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to the
    # appropriate CUDA device.
    inp = fsdp_model.module.get_input(device=torch.device("cpu"))
    fsdp_model(*inp).sum().backward()
def test_input_type(self, input_cls):
    """Test FSDP with input being a list or a dict, only single GPU."""

    class Model(Module):
        def __init__(self):
            super().__init__()
            self.layer = Linear(4, 4)

        def forward(self, input):
            if isinstance(input, list):
                input = input[0]
            else:
                assert isinstance(input, dict), input
                input = input["in"]
            return self.layer(input)

    model = FSDP(Model()).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(5):
        in_data = torch.rand(64, 4).cuda()
        in_data.requires_grad = True
        if input_cls is list:
            in_data = [in_data]
        else:
            self.assertTrue(input_cls is dict)
            in_data = {"in": in_data}

        out = model(in_data)
        out.sum().backward()
        optim.step()
        optim.zero_grad()
def test_diff_ignored_modules_across_ranks(
    self, pass_ignored_modules_to_root: bool
):
    """
    Tests ignoring different modules across ranks.

    Args:
        pass_ignored_modules_to_root (bool): If ``False``, does not pass
            any ignored modules (including those already ignored in child
            FSDP instances) to the root FSDP instance; if ``True``, passes
            all ignored modules (representing a superset of the children's
            ignored modules) to the root FSDP instance.
    """
    # To exercise different `FlatParameter` enumerations across ranks,
    # we wrap `layer3` with FSDP, where `layer3` is registered as a module
    # after `layer1`, which has the variable number of ignored modules
    model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
    layer1_ignored_modules = [
        m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
    ]
    model.layer1 = FSDP(model.layer1, ignored_modules=layer1_ignored_modules)
    model.layer3 = FSDP(model.layer3)
    model_ignored_modules = (
        [m for m in model.modules() if isinstance(m, IgnoredModule)]
        if pass_ignored_modules_to_root
        else []
    )
    wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
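# NOTE: `IgnoredModule` and `ModelWithIgnoredModules` are fixtures defined
# elsewhere in the test file. Below is a minimal sketch of what the test
# above assumes: `IgnoredModule` owns at least one parameter, and
# `ModelWithIgnoredModules` interleaves `num_ignored` of them into `layer1`,
# so different ranks construct different numbers of ignored modules. Names,
# sizes, and bodies are illustrative assumptions, not the actual fixtures.
class _SketchIgnoredModule(nn.Module):
    def __init__(self, dim: int = 4):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        return x @ self.weight


class _SketchModelWithIgnoredModules(nn.Module):
    def __init__(self, num_ignored: int):
        super().__init__()
        self.layer1 = nn.Sequential(
            nn.Linear(4, 4),
            *[_SketchIgnoredModule(4) for _ in range(num_ignored)],
            nn.Linear(4, 4),
        )
        self.layer3 = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer3(self.layer1(x))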
def test_fsdp_cpu_init_stays_on_cpu(self):
    """
    Ensure that a CPU module stays on CPU after FSDP init and that a
    warning is logged.
    """
    torch.cuda.set_device(self.rank)
    regex = "Module is put on CPU"
    context = self.assertWarnsRegex(
        expected_warning=UserWarning, expected_regex=regex
    )
    with context:
        mod = NestedWrappedModule(
            group=self.process_group,
            wrap_fsdp=True,
            wrap_everything=True,
            fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
        )
        fsdp = FSDP(mod)
    devices = {p.device for p in fsdp.parameters()}
    self.assertEqual(1, len(devices))
    self.assertEqual(torch.device("cpu"), devices.pop())
    fsdp = fsdp.cuda()
    # Ensure fwd + backward can be performed after moving to CUDA.
    # CPU input also tests that input is correctly moved to the
    # appropriate CUDA device.
    inp = mod.get_input(device=torch.device("cpu"))
    fsdp(inp[0]).sum().backward()
def _test_mixed_precision_embedding_table(self, mp_config):
    # Basic test to ensure int inputs are not cast, which would break
    # modules such as embedding tables.
    param_dtype = mp_config.param_dtype or torch.float32
    orig_reduce_scatter = dist._reduce_scatter_base
    test_reduce_scatter = partial(
        self._reduce_scatter_base_validate_mp,
        orig_reduce_scatter,
        mp_config,
    )
    with patch_reduce_scatter(test_reduce_scatter, param_dtype):
        # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
        # entire `TransformerWithSharedParams` with a single top-level FSDP
        model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            {"mixed_precision": mp_config},
        )
        fsdp_model = FSDP(model, mixed_precision=mp_config)
        optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
        for _ in range(6):
            inp = fsdp_model.module.get_input(torch.device("cuda"))
            # This would fail if we cast integer module inputs such as
            # those for embedding tables.
            output = fsdp_model(*inp)
            loss = fsdp_model.module.get_loss(inp, output).cuda()
            self.assertEqual(loss.dtype, param_dtype)
            fsdp_model.module.run_backward(loss)
            optim.step()
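# NOTE: `patch_reduce_scatter` is used above but defined elsewhere in the
# test file. Below is a minimal sketch of its likely shape, assuming it
# simply monkey-patches `dist._reduce_scatter_base` with the validating
# wrapper and restores the original on exit. The exact signature and body
# are assumptions; `param_dtype` is accepted here only to mirror the call
# site above (the real helper may use it for validation).
from contextlib import contextmanager


@contextmanager
def _sketch_patch_reduce_scatter(new_reduce_scatter, param_dtype):
    orig_reduce_scatter = dist._reduce_scatter_base
    dist._reduce_scatter_base = new_reduce_scatter
    try:
        yield
    finally:
        # Always restore the original collective, even if the test fails
        dist._reduce_scatter_base = orig_reduce_scatter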
def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module
    model = Model(wrap_fsdp=True).cuda()
    ignored_modules = [model.outer]
    ignored_param_to_param_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters are not cloned
    for param, param_name in ignored_param_to_param_name.items():
        self.assertTrue(param_name in sd)
        self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version
    # of the model
    nonwrapped_model = Model(wrap_fsdp=False).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
def test_ignored_modules_transformer(self):
    """Tests that ignored modules' parameters are not flattened for a
    transformer model with shared parameters."""
    # Initialize an FSDP-wrapped transformer model that has FSDP ignore
    # the `nn.Transformer` module's parameters
    model: nn.Module = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    wrapped_model = FSDP(
        model,
        self.process_group,
        ignored_modules=[model.transformer],
    )
    # Check that the wrapped model's flattened parameter does not include
    # the ignored transformer module's parameters
    nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.transformer.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
def test_multiple_wrapping(self):
    """
    This test simulates wrapping the module after training to run
    inference. This is required in cases where, later in a session, the
    model is wrapped again in FSDP but contains nested FSDP wrappers
    within the module.
    """
    inner_model = InnerModel()
    model = FSDP(inner_model).cuda()
    optim = SGD(model.parameters(), lr=0.1)

    for _ in range(3):
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        input.requires_grad = True
        output = model(input)
        output.sum().backward()
        optim.step()
        optim.zero_grad()
    input = torch.rand((1, 5), dtype=torch.float).cuda()
    output = model(input)

    # Wrap the inner model a second time
    rewrapped_model = FSDP(inner_model).cuda()
    rewrapped_output = rewrapped_model(input)

    self.assertEqual(output, rewrapped_output)
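# NOTE: `InnerModel` is defined elsewhere in the test file. A minimal sketch
# of the structure the rewrapping test above assumes: the inner model already
# contains a nested FSDP instance, so `FSDP(inner_model)` on the second wrap
# produces nested FSDP wrappers. Layer sizes are illustrative assumptions.
class _SketchInnerModel(nn.Module):
    def __init__(self):
        super().__init__()
        # The nested FSDP wrapper survives when the whole model is
        # (re)wrapped with an outer FSDP instance
        self.layers = nn.Sequential(
            FSDP(nn.Linear(5, 5).cuda()),
        )

    def forward(self, x):
        return self.layers(x)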
def test_params_are_unflattenned(self, rank0_only, offload_to_cpu):
    layer_shape = (10, 12)
    model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
    fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

    def _get_flat_param():
        return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

    flattened_param = _get_flat_param()
    # Assumes world size 2: each rank holds half of the flattened parameter
    self.assertEqual(
        layer_shape[0] * layer_shape[1] / 2, flattened_param.numel()
    )

    with fsdp_model.summon_full_params(
        rank0_only=rank0_only,
        writeback=not rank0_only,
        offload_to_cpu=offload_to_cpu,
    ):
        if self.rank == 0 or not rank0_only:
            self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
            expected_device = (
                torch.device("cpu")
                if offload_to_cpu
                else torch.device("cuda", torch.cuda.current_device())
            )
            self.assertTrue(expected_device == fsdp_model.weight.device)
        else:
            # Nonzero rank with rank0_only maintains original params.
            flat_within_ctx = _get_flat_param()
            self.assertEqual(flat_within_ctx, flattened_param)
            self.assertEqual(
                flat_within_ctx.device,
                torch.device(torch.cuda.current_device()),
            )

    # CPU offload should restore the param device
    param = next(fsdp_model.parameters())
    self.assertTrue(
        param.device == torch.device("cuda", torch.cuda.current_device())
    )
def _test_nested_model_with_meta_device(
    self, auto_wrap, meta_module_fn, init_fn=None
):
    if auto_wrap:
        module = meta_module_fn()
        is_meta = next(module.parameters()).is_meta
        fsdp_meta = FSDP(
            module,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        module_regular = NestedModel(device="cuda")
        _reset_params_if_meta(is_meta, module_regular)
        fsdp_regular = FSDP(
            module_regular,
            auto_wrap_policy=always_wrap,
        )
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
    else:
        with enable_wrap(
            wrapper_cls=FSDP,
            param_init_fn=init_fn,
        ):
            module = meta_module_fn()
            is_meta = next(module.parameters()).is_meta
            # Non-FSDP modules will still be initialized because they
            # bubble up to be part of a larger FSDP unit.
            fsdp_meta = wrap(module)
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Init and reset parameters before wrapping so that reset_params
        # matches up with the meta device's initialization.
        module_regular = NestedModel(device="cuda")
        _reset_params_if_meta(is_meta, module_regular)
        with enable_wrap(wrapper_cls=FSDP):
            module_regular.lin1 = wrap(module_regular.lin1)
            module_regular.l3 = wrap(module_regular.l3)
            fsdp_regular = wrap(module_regular)
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

    # Compare the models before training
    self._compare_fsdp(fsdp_meta, fsdp_regular)
    inp = torch.randn(10, 2, device="cuda")
    fsdp_meta(inp).sum().backward()
    fsdp_regular(inp).sum().backward()
    meta_opt.step()
    regular_opt.step()
    self._compare_fsdp(fsdp_meta, fsdp_regular)
def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
    # Create model on meta device and wrap with FSDP.
    model = meta_module_fn()
    is_meta = next(model.parameters()).is_meta
    fsdp_meta = FSDP(
        model,
        auto_wrap_policy=always_wrap,
        param_init_fn=init_fn,
    )
    meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
    # Make sure it has the same model parameters as the regular FSDP
    # approach.
    regular = MyModel(device="cuda")
    _reset_params_if_meta(is_meta, regular)
    fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
    regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
    self._compare_fsdp(fsdp_meta, fsdp_regular)
    inp = torch.randn(10, 2, device="cuda")
    fsdp_meta(inp).sum().backward()
    fsdp_regular(inp).sum().backward()
    meta_opt.step()
    regular_opt.step()
    self._compare_fsdp(fsdp_meta, fsdp_regular)

    # Test that meta init works if all submodules are contained in only a
    # single FSDP unit.
    model = meta_module_fn()
    fsdp_meta = FSDP(model, param_init_fn=init_fn)
    meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
    regular = MyModel(device="cuda")
    _reset_params_if_meta(is_meta, regular)
    fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
    regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
    # Run a forward + backward pass + optimizer step
    fsdp_meta(inp).sum().backward()
    fsdp_regular(inp).sum().backward()
    meta_opt.step()
    regular_opt.step()
    self._compare_fsdp(fsdp_meta, fsdp_regular)
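# NOTE: `_reset_params_if_meta` is used by the meta-device tests above but
# defined elsewhere. A minimal sketch, under the assumption that when the
# original module was on the meta device, the materialized reference module
# must re-run its parameter initialization so that both models start from
# the same state; the real helper may differ.
def _sketch_reset_params_if_meta(is_meta: bool, model: nn.Module):
    if is_meta:
        for module in model.modules():
            # Re-initialize any submodule that defines an initializer
            if hasattr(module, "reset_parameters"):
                module.reset_parameters()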
def _dist_train(
    self, with_nested_trunk, freezing_method, freeze_after_wrap_fsdp, with_fsdp
):
    torch.manual_seed(0)
    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = self._create_model(
        with_fsdp, with_nested_trunk, freeze_after_wrap_fsdp
    )
    model = model.cuda()

    # freezing the trunk using requires_grad.
    if freezing_method == FreezingMethod.RequiresGrad:
        for param in model.trunk.parameters():
            param.requires_grad = False

    if with_fsdp:
        if not freeze_after_wrap_fsdp:
            model.fsdp_wrap()
        model = FSDP(model)
    else:
        model = DistributedDataParallel(model, device_ids=[self.rank])

    target = torch.tensor([0, 1], dtype=torch.long).cuda()
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

    for _ in range(3):
        out = model(batch)
        fake_loss = criterion(out, target)
        optimizer.zero_grad()
        fake_loss.backward()
        if freezing_method == FreezingMethod.GradToNone:
            if with_fsdp:
                for param in model.module.module.trunk.parameters():
                    param.grad = None
            else:
                for param in model.module.trunk.parameters():
                    param.grad = None
        optimizer.step()

    if with_fsdp:
        get_full_params(model)

    return list(model.parameters())
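# NOTE: `get_full_params` is used above (and in the parity tests below) but
# defined elsewhere in the test file. A minimal sketch consistent with its
# usage: gather the unsharded parameters on each rank and return detached
# copies. The real helper may accept more options; `deepcopy` is assumed to
# be imported as elsewhere in this file.
def _sketch_get_full_params(model: nn.Module, recurse: bool = True):
    with FSDP.summon_full_params(model, recurse=recurse):
        # Copy while the full (unsharded) parameters are materialized
        return deepcopy(list(model.parameters()))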
def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    params_to_compare = (
        [p.clone() for p in model.parameters()]
        if rank0_only and self.rank != 0
        else list(local_model.parameters())
    )

    writeback = not rank0_only
    with model.summon_full_params(
        model,
        recurse=True,
        rank0_only=rank0_only,
        writeback=writeback,
        offload_to_cpu=offload_to_cpu,
    ):
        if writeback:
            with torch.no_grad():
                for p in model.parameters():
                    p.add_(1)
                for p in params_to_compare:
                    p.add_(1)
        # The sleep below causes failures without the stream
        # synchronization fix in summon_full_params.
        torch.cuda._sleep(1000000)
        # deepcopy() of FSDP params has issues, so compare clones instead
        fsdp_params = [p.clone() for p in model.parameters()]

    self.assertEqual(fsdp_params, params_to_compare)

    # CPU offload is enabled for the main API, so params should point back
    # to CPU after exiting the context
    for param in model.parameters():
        self.assertEqual(param.device, torch.device("cpu"))
def test_summon_full_params_equivalence(self):
    offload = CPUOffload(offload_params=True)
    model = FSDP(
        DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
        cpu_offload=offload,
    )
    local_model = DeterministicModel(wrap_fsdp=False)

    with model.summon_full_params(recurse=True):
        # The sleep below causes failures without the stream
        # synchronization fix in summon_full_params.
        torch.cuda._sleep(1000000)
        fsdp_params = deepcopy(list(model.parameters()))

    self.assertEqual(fsdp_params, list(local_model.parameters()))
def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
    gpu_id = self.rank

    batch = torch.randn(size=(2, 3, 224, 224)).cuda()

    model = create_model(
        with_fsdp=True,
        with_checkpoint=with_checkpoint,
        model_hidden_dim=model_hidden_dim,
    )
    model = model.cuda()
    model = FSDP(model)

    # We enable momentum so that after the first iteration, the optimizer
    # state is added to the total memory used.
    criterion = nn.MSELoss()
    optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

    results = {}  # memory stats per label
    for iteration in range(iterations):
        get_cur_mem(gpu_id, results, f"iter {iteration}: start")

        out = model(batch)
        get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

        out = sum(o.sum() for o in out[0])
        fake_loss = criterion(out, torch.tensor(0.0).cuda())
        get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

        fake_loss.backward()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

        optimizer.step()
        get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

        # It is important to use `set_to_none` below, rather than
        # `optimizer.zero_grad()`, to reclaim memory.
        model.zero_grad(set_to_none=True)
        get_cur_mem(gpu_id, results, f"iter {iteration}: done")

    def cmp(results, expected):
        ret = ""
        self.assertEqual(results.keys(), expected.keys())
        for k, v in results.items():
            exp = expected[k]
            if abs(exp - v) > 1:  # allow 1MB rounding differences
                ret += f"{k}: got {v}, expected {exp}\n"
        return ret

    output = cmp(results, expected)
    self.assertEqual(output, "")
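# NOTE: `get_cur_mem` is used above but defined elsewhere in the test file.
# A minimal sketch matching its call sites: record the currently allocated
# CUDA memory, in MB, under a label. The MB rounding is an assumption based
# on the 1MB tolerance used in `cmp` above.
def _sketch_get_cur_mem(gpu_id: int, result: dict, prefix: str):
    result[prefix] = round(torch.cuda.memory_allocated(gpu_id) / 1024 / 1024)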
def test_params_count_and_value(self):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    with fsdp_model.summon_full_params():
        for p1, p2 in itertools.zip_longest(
            fsdp_model.parameters(), model.module.parameters()
        ):
            self.assertEqual(p1, p2)
def test_state_dict_with_ignored_modules(self):
    # Initialize an FSDP-wrapped model with an ignored module that
    # includes both parameters and a buffer
    model = Model(wrap_fsdp=True, register_buffers=True).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
        model.outer.buffer: "outer.buffer",
    }
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict()
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)
        self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
    # Check that the state dict can be loaded into a non-wrapped version
    # of the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    nonwrapped_model.load_state_dict(sd1)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that if we save a state dict again, the ignored parameters
    # and buffers still have the same data pointers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict()
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        self.assertTrue(tensor_name in sd1)  # check again just in case
        self.assertTrue(tensor_name in sd2)
        self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
        self.assertEqual(
            sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr()
        )
def test_ignored_modules_nested(self):
    """Tests that passing a module with nested FSDP modules does not
    error and still ignores non-FSDP modules' parameters."""
    # Initialize an FSDP-wrapped nested model that first wraps the nested
    # sequential's second linear layer (`layer1[1]`) and then wraps the
    # overall model while ignoring the nested sequential (`layer1`)
    model = Model().cuda()
    model.layer1[1] = FSDP(model.layer1[1])
    wrapped_model = FSDP(model, ignored_modules=[model.layer1])
    # Check that the wrapped model's flattened parameter does not include
    # the ignored nested sequential's parameters
    nonwrapped_model = Model()
    total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
    ignored_numel = sum(
        p.numel() for p in nonwrapped_model.layer1.parameters()
    )
    nonignored_numel = total_numel - ignored_numel
    with FSDP.summon_full_params(wrapped_model):
        flat_param_numel = wrapped_model.params[0].numel()
        self.assertEqual(flat_param_numel, nonignored_numel)
    # Check that we can run a few iterations
    optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
    self._train_model(wrapped_model, optim, 3)
def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
    # Initialize an FSDP-wrapped model with an ignored module that
    # includes both parameters and a buffer
    model = Model(
        wrap_fsdp=True, register_buffers=True, ignore_inner=ignore_inner
    ).cuda()
    ignored_modules = [model.outer]
    ignored_tensor_to_tensor_name = {
        model.outer.bias: "outer.bias",
        model.outer.weight: "outer.weight",
    }
    if ignore_inner:
        ignored_tensor_to_tensor_name = {
            **ignored_tensor_to_tensor_name,
            model.inner.bias: "inner.bias",
            model.inner.weight: "inner.weight",
        }
    # Note that when `model.inner` is not ignored, this test also ensures
    # that non-ignored buffers are not cloned.
    buffer_to_buffer_name = {
        model.inner.buffer: "inner.buffer",
        model.outer.buffer: "outer.buffer",
    }
    fsdp_model = FSDP(model, ignored_modules=ignored_modules)
    prefix_str = "foo." if prefix else ""
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd1 = fsdp_model.state_dict(prefix=prefix_str)
    with FSDP.summon_full_params(fsdp_model):
        fsdp_params = deepcopy(list(fsdp_model.parameters()))
    # Check that the ignored parameters and all buffers are not cloned
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        prefixed_tensor_name = f"{prefix_str}{tensor_name}"
        self.assertTrue(prefixed_tensor_name in sd1)
        self.assertEqual(
            tensor.data_ptr(),
            sd1[prefixed_tensor_name].data_ptr(),
            f"{prefixed_tensor_name}",
        )
    # Check that the state dict can be loaded into a non-wrapped version
    # of the model
    nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
    for param in nonwrapped_model.parameters():
        with torch.no_grad():
            param.zero_()
    to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
    nonwrapped_model.load_state_dict(to_load, strict=True)
    local_params = list(nonwrapped_model.parameters())
    for fsdp_param, local_param in zip(fsdp_params, local_params):
        self.assertEqual(fsdp_param, local_param)
    # Check that if we save a state dict again, the ignored parameters
    # and buffers still have the same data pointers
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
        sd2 = fsdp_model.state_dict(prefix=prefix_str)
    for tensor, tensor_name in {
        **ignored_tensor_to_tensor_name,
        **buffer_to_buffer_name,
    }.items():
        prefixed_tensor_name = f"{prefix_str}{tensor_name}"
        self.assertTrue(prefixed_tensor_name in sd2)
        self.assertEqual(
            tensor.data_ptr(), sd2[prefixed_tensor_name].data_ptr()
        )
        self.assertEqual(
            sd1[prefixed_tensor_name].data_ptr(),
            sd2[prefixed_tensor_name].data_ptr(),
        )
def test_state_dict_rank0_offload_save_load_flow(self):
    """Tests saving a model checkpoint only on rank 0 and loading it only
    on rank 0 with ``sync_module_states=True`` to emulate a workflow that
    avoids redundant CPU memory usage."""
    auto_wrap_policy = partial(
        transformer_auto_wrap_policy,
        transformer_layer_cls={
            TransformerEncoderLayer,
            TransformerDecoderLayer,
        },
    )
    fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
    fsdp_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.RECURSIVE,
        CUDAInitMode.CUDA_BEFORE,
        fsdp_kwargs,
    )
    # Force model parameters and buffers to be nonzero
    with FSDP.summon_full_params(fsdp_model):
        for tensor in itertools.chain(
            fsdp_model.parameters(), fsdp_model.buffers()
        ):
            if torch.count_nonzero(tensor) == 0:
                with torch.no_grad():
                    tensor.add_(
                        torch.tensor(
                            1, dtype=tensor.dtype, device=tensor.device
                        )
                    )
    with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
        state_dict = deepcopy(_get_state_dict(fsdp_model))
    # Initialize a non-wrapped model on all ranks
    new_model = TransformerWithSharedParams.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
    )
    _zero_model(new_model, zero_buffers=True)
    # Only load the checkpoint on rank 0
    if self.rank == 0:
        new_model.load_state_dict(state_dict, strict=True)
    _assert_module_states(
        new_model,
        process_group=self.process_group,
        assert_fn=self.assertNotEqual,
    )
    # Broadcast the module states from rank 0 with `sync_module_states=True`
    new_fsdp_model = FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        auto_wrap_policy=auto_wrap_policy,
        sync_module_states=True,
    )
    # Check that the FSDP models are equal across ranks
    with FSDP.summon_full_params(new_fsdp_model):
        _assert_module_states(
            new_fsdp_model,
            process_group=self.process_group,
            assert_fn=self.assertEqual,
        )
    # Check that the FSDP models correctly loaded the checkpoint
    with FullyShardedDataParallel.summon_full_params(fsdp_model):
        with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
            params = list(fsdp_model.parameters())
            params_new = list(new_fsdp_model.parameters())
            self.assertEqual(params, params_new)
def _test_fsdp_parity(
    self,
    model_class: Type[FSDPTestModel],
    fsdp_init_mode: FSDPInitMode,
    cuda_init_mode: CUDAInitMode,
    ref_init_fn: Optional[Callable] = None,
    num_iters: int = 2,
    save_model: bool = True,
    cpu_offload: CPUOffload = CPUOffload(),
    backward_prefetch: Optional[BackwardPrefetch] = None,
    forward_prefetch: bool = False,
    sharding_strategy: Optional[ShardingStrategy] = None,
    mixed_precision: Optional[MixedPrecision] = None,
    enable_sharded_grad_scaler: bool = False,
    use_pure_fp16: bool = False,
    norm_type: Optional[Union[float, int]] = None,
    init_kwargs: Optional[Dict[str, Any]] = None,
    **fsdp_kwargs,
):
    """
    Tests FSDP training against a reference, which defaults to DDP but
    may be customized with ``ref_init_fn``.

    Args:
        model_class (Type[FSDPTestModel]): A model class that inherits
            from ``FSDPTestModel``, which defines the expected interface.
        fsdp_init_mode (FSDPInitMode): The mode to initialize the
            FSDP-wrapped model. This should not be ``NO_FSDP``.
        ref_init_fn (Optional[Callable]): A callable to invoke that wraps
            a non-wrapped model to construct the reference model, where
            this wrapper should provide data parallel semantics. If
            ``None``, then the callable defaults to the DDP constructor.
    """
    assert (
        fsdp_init_mode != FSDPInitMode.NO_FSDP
    ), "Expects an FSDP init mode that wraps with FSDP"
    if init_kwargs is None:
        init_kwargs = {}
    lr = 1e-2
    rank = self.process_group.rank()
    # Establish reference behavior with DDP
    model = model_class.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
        **init_kwargs,
    )
    if ref_init_fn is None:
        ref_model = DDP(model, device_ids=[rank], output_device=rank)
    else:
        ref_model = ref_init_fn(model)
    if use_pure_fp16:
        ref_model = ref_model.half()
    ref_loss = self._train_for_several_steps(
        ref_model,
        num_iters,
        autocast=mixed_precision is not None,
        lr=lr,
        fsdp_cpu_offload=cpu_offload,
        mixed_precision=mixed_precision,
        norm_type=norm_type,
        enable_sharded_grad_scaler=enable_sharded_grad_scaler,
        use_pure_fp16=use_pure_fp16,
    )
    ddp_params = list(ref_model.parameters())
    # Check against FSDP behavior
    fsdp_kwargs.update(
        {
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        }
    )
    try:
        fsdp_model = model_class.init(
            self.process_group,
            fsdp_init_mode,
            cuda_init_mode,
            fsdp_kwargs,
            deterministic=True,
            **init_kwargs,
        )
    except Exception as e:
        raise ValueError(f"Initializing {model_class} raised error {str(e)}")
    if not isinstance(fsdp_model, FSDP):
        # Enforce that we wrap with top-level FSDP since we are comparing
        # assuming a data parallel reference and some test models may not
        # do so in their `init()` method
        fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
    if use_pure_fp16:
        # Change the model parameter dtype after FSDP initialization
        fsdp_model = fsdp_model.half()
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
        fsdp_model = fsdp_model.cuda()
    offload_params = cpu_offload is not None and cpu_offload.offload_params
    # Offloading parameters with `CUDA_AFTER` should raise an error during
    # lazy initialization due to the parameter devices not being CPU;
    # otherwise, all parameter devices should be CPU
    expects_device_error = (
        offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
    )
    expects_cpu_device = (
        offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
    )
    if expects_cpu_device:
        cpu_device = torch.device("cpu")
        for param in fsdp_model.parameters():
            self.assertEqual(param.device, cpu_device)
    context = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if expects_device_error
        else suppress()
    )
    with context:
        fsdp_loss = self._train_for_several_steps(
            fsdp_model,
            num_iters,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            save_model=save_model,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
    # No need to check for parameter and loss parity if expecting an error
    if expects_device_error:
        return
    # Check parameter devices are CPU if offloading to CPU before calling
    # `get_full_params()`, which will cast the parameters to FP32
    if offload_params:
        for param in fsdp_model.parameters():
            self.assertEqual(param.device, cpu_device)
        fsdp_loss = fsdp_loss.cuda()
    fsdp_unsharded_params = get_full_params(fsdp_model)
    torch.testing.assert_allclose(ref_loss, fsdp_loss)
    # Do not check for parameter parity if using mixed precision since (1)
    # the DDP parameters are in FP16 (from `half()`) while the FSDP
    # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
    # the optimizer in FP16 while FSDP runs it in FP32
    if mixed_precision is None:
        self.assertEqual(
            ddp_params,
            fsdp_unsharded_params,
            exact_device=True,
            msg="FSDP did not match DDP",
        )
def _zero_model(fsdp_model: FullyShardedDataParallel):
    with fsdp_model.summon_full_params():
        for param in fsdp_model.parameters():
            with torch.no_grad():
                param.zero_()
def _test_identical_outputs(
    self,
    model_init_fn,
    *args,
    ref_ddp_fn=None,
    num_steps=2,
    fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
    lr=0.01,
    cpu_offload=CPUOffload(),
    backward_prefetch=None,
    sharding_strategy=None,
    save_model=True,
    clip_norm=0.3,
    norm_type=None,
    **kwargs,
):
    group = dist.distributed_c10d._get_default_group()
    rank = group.rank()
    # Establish reference behavior with PyTorch DDP (+ optionally autocast).
    model = model_init_fn(group=group, wrap_fsdp=False).cuda()
    if ref_ddp_fn is None:
        model = nn.parallel.DistributedDataParallel(
            model, device_ids=[rank], output_device=rank
        )
    else:
        model = ref_ddp_fn(model)
    # DDP training
    ref_loss = self._train_for_several_steps(
        model, num_steps, autocast=False, lr=lr, fsdp_cpu_offload=cpu_offload
    )
    ref_full_params = list(model.parameters())

    # Confirm we get the same behavior using FullyShardedDataParallel.
    try:
        model = model_init_fn(
            group=group,
            wrap_fsdp=True,
            fsdp_init_mode=fsdp_init_mode,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
        )
    except Exception as e:
        raise ValueError(f"model_init_fn {model_init_fn} got error {str(e)}")
    cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified
    model = FullyShardedDataParallel(
        model,
        cpu_offload=cpu_offload,
        backward_prefetch=backward_prefetch,
        sharding_strategy=sharding_strategy,
    )
    # Call model.cuda() after init FSDP if specified.
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        model = model.cuda()

    # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since
    # we expect FSDP code to raise the error that we check below, in the
    # case of offloaded params.
    if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        for p in model.parameters():
            # Should be on CPU regardless of whether the param is sharded.
            self.assertEqual(
                p.device,
                torch.device("cpu"),
                f"Mismatch, cpu offload is {cpu_offload}",
            )

    only_check_err = (
        fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
    )
    ctx = (
        self.assertRaisesRegex(AssertionError, "Expected param to be on CPU")
        if only_check_err
        else suppress()
    )
    with ctx:
        # FSDP training
        shard_loss = self._train_for_several_steps(
            model,
            num_steps,
            autocast=False,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            save_model=save_model,
        )
    # We only check for errors in the case we have the following setup:
    #   model = FSDP(model, cpu_offload=True)
    #   model = model.cuda()
    # so skip the rest of this logic.
    if only_check_err:
        return
    # If CPU offload, the next call will move model params to GPU. Sanity
    # check that params are on CPU before.
    if cpu_offload.offload_params:
        device_set = {p.device for p in model.parameters()}
        self.assertEqual(
            {torch.device("cpu")}, device_set, f"Got device set {device_set}"
        )
    shard_full_params = get_full_params(model)

    if cpu_offload.offload_params:
        shard_loss = shard_loss.cuda()
    torch.testing.assert_allclose(ref_loss, shard_loss)
    self.assertEqual(
        ref_full_params,
        shard_full_params,
        exact_device=True,
        msg="FullyShardedDataParallel didn't match PyTorch DDP",
    )
def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
    with fsdp_model.summon_full_params():
        params = [p.clone().detach_() for p in fsdp_model.parameters()]
    return params