def _run_test_summon_full_param_writeback(cls, writeback, modify_outer, *fsdp_args, **fsdp_kwargs):
    with enable_wrap(wrapper_cls=FSDP, *fsdp_args, **fsdp_kwargs):
        lin1 = wrap(nn.Linear(5, 5, bias=False).cuda(cls.rank))
        lin2 = nn.Linear(5, 3, bias=False).cuda(cls.rank)
        model = wrap(nn.Sequential(lin1, lin2))

    # Pick the outer or inner flat parameter and set its local shard value.
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
    )
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(model, writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard and the parameter is not
        # replaced by a local shard, so the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
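# Illustrative sketch (not part of the test suite) of the writeback behavior
# exercised above: the flag controls whether edits made under
# summon_full_params are written back to the local shards on exit. Assumes a
# process group is already initialized (e.g. via torchrun) and CUDA is
# available.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = FSDP(nn.Linear(5, 5).cuda())
with FSDP.summon_full_params(model, writeback=False):
    with torch.no_grad():
        for p in model.parameters():
            p.zero_()  # discarded on exit because writeback=False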
def __init__(self, device):
    super().__init__()
    # lin1 and l3 are wrapped as nested FSDP units; lin2 is left unwrapped.
    self.lin1 = MyLinear(2, 2, bias=False, device=device)
    self.lin1 = wrap(self.lin1)
    self.lin2 = MyLinear(2, 2, bias=False, device=device)
    self.l3 = MyModel(device=device)
    self.l3 = wrap(self.l3)
def configure_sharded_model(self) -> None:
    # the model is already wrapped with FSDP: no need to wrap again!
    if isinstance(self.layer, FullyShardedDataParallel):
        return
    for i, layer in enumerate(self.layer):
        if i % 2 == 0:
            self.layer[i] = wrap(layer)
    self.layer = wrap(self.layer)
def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
    sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
    ignored_modules = [sequential[1], sequential[2][0]]
    my_auto_wrap_policy = functools.partial(
        size_based_auto_wrap_policy,
        min_num_params=40,
    )
    fsdp_kwargs = {
        "process_group": self.process_group,
        "auto_wrap_policy": my_auto_wrap_policy,
        "ignored_modules": ignored_modules,
    }
    if wrap_method == WrapMethod.FSDP_CTOR:
        model = FSDP(sequential, **fsdp_kwargs)
    elif wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
            model = wrap(sequential)
    else:
        assert 0, f"Unsupported wrap method: {wrap_method}"
    # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping policy
    # does not exceed the parameter threshold before the inner sequential
    # (`sequential[2]`) anymore; hence, it flattens `sequential[0]` and
    # `sequential[2][1]` into `model` and leaves `sequential[1]` and
    # `sequential[2][0]` as-is since they are ignored
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model.module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[1], nn.Linear))
    self.assertTrue(isinstance(model.module[2], nn.Sequential))
    self.assertTrue(isinstance(model.module[2][0], nn.Linear))
    self.assertTrue(isinstance(model.module[2][1], nn.Linear))
def test_wrap_override_defaults(self):
    new_process_group = DummyProcessGroup(rank=0, size=2)
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
    self.assertTrue(isinstance(layer, FSDP))
    self.assertTrue(layer.process_group is new_process_group)
    self.assertEqual(layer.rank, 0)
    self.assertEqual(layer.world_size, 2)
def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn=None):
    if auto_wrap:
        module = meta_module_fn()
        is_meta = next(module.parameters()).is_meta
        fsdp_meta = FSDP(
            module,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        module_regular = NestedModel(device="cuda")
        _reset_params_if_meta(is_meta, module_regular)
        fsdp_regular = FSDP(
            module_regular,
            auto_wrap_policy=always_wrap,
        )
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
    else:
        with enable_wrap(
            wrapper_cls=FSDP,
            param_init_fn=init_fn,
        ):
            module = meta_module_fn()
            is_meta = next(module.parameters()).is_meta
            # Non-FSDP modules will still be initialized because they bubble
            # up to be part of a larger FSDP unit.
            fsdp_meta = wrap(module)
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Init and reset parameters before wrapping so that reset_params
        # matches up with meta device's initialization.
        module_regular = NestedModel(device="cuda")
        _reset_params_if_meta(is_meta, module_regular)
        with enable_wrap(wrapper_cls=FSDP):
            module_regular.lin1 = wrap(module_regular.lin1)
            module_regular.l3 = wrap(module_regular.l3)
            fsdp_regular = wrap(module_regular)
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

    # Compare the two models before training.
    self._compare_fsdp(fsdp_meta, fsdp_regular)
    inp = torch.randn(10, 2, device="cuda")
    fsdp_meta(inp).sum().backward()
    fsdp_regular(inp).sum().backward()
    meta_opt.step()
    regular_opt.step()
    self._compare_fsdp(fsdp_meta, fsdp_regular)
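# Illustrative sketch (not part of the test suite): one common shape for the
# param_init_fn passed above is to materialize meta-device parameters on the
# local device via Module.to_empty() and then re-initialize them. The module
# below is a stand-in; assumes an initialized process group and CUDA device.
import torch
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

meta_module = nn.Linear(2, 2, bias=False, device="meta")

def my_init_fn(module: nn.Module) -> None:
    # Allocate real storage for the meta parameters, then reset them.
    module.to_empty(device=torch.device("cuda"))
    if hasattr(module, "reset_parameters"):
        module.reset_parameters()

fsdp_model = FSDP(meta_module, param_init_fn=my_init_fn)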
def test_state_dict_type(self):
    module = SkipModel(double_nest=True)
    with enable_wrap(wrapper_cls=FSDP):
        fsdp = wrap(module)
    with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
        pass
    # Exiting the context should restore the default state dict type.
    for module in FSDP.fsdp_modules(fsdp):
        self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
def test_distributed_checkpoint(self, state_dict_type) -> None:
    with enable_wrap(wrapper_cls=FSDP):
        torch.manual_seed(100)
        model = wrap(SkipModel(double_nest=True))
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))

    with FullyShardedDataParallel.summon_full_params(
        model
    ), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertNotEqual(params, new_params)

    with tempfile.TemporaryDirectory() as path:
        paths = [path]
        dist.broadcast_object_list(paths)
        path = paths[0]
        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = model.state_dict()
            save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type
        ):
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

    with FullyShardedDataParallel.summon_full_params(
        model
    ), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertEqual(params, new_params)
def test_wrap(self, wrap_method):
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        layer = FSDP(
            nn.Linear(5, 5),
            process_group=self.process_group,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy, min_num_params=1
            ),
        )
    self.assertTrue(isinstance(layer, FSDP))
    self.assertEqual(layer.rank, self.process_group.rank())
    self.assertEqual(layer.world_size, self.process_group.size())
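# Illustrative sketch (not part of the test suite): the two styles exercised by
# test_wrap are interchangeable. Assumes a default process group is already
# initialized; with no process_group kwarg, FSDP uses the default group.
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

# Constructor style: wrap the module directly.
fsdp_layer = FSDP(nn.Linear(5, 5))

# Context-manager style: wrap() picks up the kwargs passed to enable_wrap().
with enable_wrap(wrapper_cls=FSDP):
    fsdp_layer = wrap(nn.Linear(5, 5))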
def test_wrap_disabled_outside_context(self):
    pg = self.process_group

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin = wrap(nn.Linear(5, 5), process_group=pg)

    model = MyModel()
    with enable_wrap(wrapper_cls=FSDP, process_group=pg):
        model = wrap(model)

    self.assertTrue(isinstance(model, FSDP))
    self.assertFalse(isinstance(model.lin, FSDP))
    self.assertTrue(isinstance(model.lin, nn.Linear))
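# Illustrative sketch (not part of the test suite): outside an enable_wrap()
# context, wrap() is a no-op and returns the module unchanged, which is the
# behavior test_wrap_disabled_outside_context relies on for `self.lin`.
import torch.nn as nn
from torch.distributed.fsdp.wrap import wrap

lin = nn.Linear(5, 5)
assert wrap(lin) is lin  # no enable_wrap() context, so nothing is wrapped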
def _create_module(wrap_fsdp=True):
    LINEAR_SKIP = "linear_skip"
    ctx = enable_wrap(wrapper_cls=FSDP) if wrap_fsdp else suppress()
    with ctx:
        module = SkipModel(double_nest=double_nest)

        # Full name of linear_skip param tensors in SkipModel, as would be
        # stored in checkpoint.
        linear_skip_tensor_names = [
            k for k in dict(module.named_parameters()).keys()
            if LINEAR_SKIP in k
        ]

        # skip SkipModule
        linear_skip = getattr(module, LINEAR_SKIP)
        delattr(module, LINEAR_SKIP)

        # Wrap FSDP
        fsdp = wrap(module)

        # reattach
        setattr(module, LINEAR_SKIP, linear_skip)
        return fsdp, linear_skip_tensor_names
def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
    sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
    ignored_modules = [sequential[1], sequential[2][0]]
    fsdp_kwargs = {
        "process_group": self.process_group,
        "auto_wrap_policy": always_wrap_policy,
        "ignored_modules": ignored_modules,
    }
    if wrap_method == WrapMethod.FSDP_CTOR:
        model = FSDP(sequential, **fsdp_kwargs)
    elif wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
            model = wrap(sequential)
    else:
        assert 0, f"Unsupported wrap method: {wrap_method}"
    # All non-ignored modules should be wrapped with FSDP
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model.module[0], FSDP))
    self.assertTrue(isinstance(model.module[1], nn.Linear))
    self.assertTrue(isinstance(model.module[2], FSDP))
    self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[2].module[1], FSDP))
def __init__(self):
    super().__init__()
    self.lin = wrap(nn.Linear(5, 5), process_group=pg)
def __init__(self, double_nest):
    super().__init__()
    self.linear = nn.Linear(10, 10, bias=False).cuda()
    self.linear_skip = SkipModule().cuda()
    self.nested_linear = wrap(NestedLinear(fsdp_wrap=double_nest))
def __init__(self, fsdp_wrap):
    super().__init__()
    if fsdp_wrap:
        self.nested_linear = wrap(nn.Linear(10, 10, bias=False).cuda())
    else:
        self.nested_linear = nn.Linear(10, 10, bias=False).cuda()