def model_sharded_context(self) -> Generator:
    """Yield once while ``enable_wrap`` is active with this object's FSDP settings.

    Logs entry at ``detail`` level, then opens ``enable_wrap`` with
    ``FullyShardedDataParallel`` as the wrapper class, forwarding the process
    group, CPU offload, backward prefetch, mixed-precision config, the root
    device index and any extra ``self.kwargs``.

    NOTE(review): presumably decorated as a context manager outside this
    block -- confirm against the enclosing class.
    """
    log.detail(f"{self.__class__.__name__}: entered model_sharded_context.")
    # Keep this as a keyword call (not a merged dict): a key in ``self.kwargs``
    # that duplicates one of the explicit keywords must still raise.
    wrap_scope = enable_wrap(
        wrapper_cls=FullyShardedDataParallel,
        process_group=self.process_group,
        cpu_offload=self.cpu_offload,
        backward_prefetch=self.backward_prefetch,
        mixed_precision=self.mixed_precision_config,
        device_id=self.root_device.index,
        **self.kwargs,
    )
    with wrap_scope:
        yield
def test_wrap(self, wrap_method):
    """Check that either wrapping entry point yields an FSDP layer on the test group.

    Args:
        wrap_method: ``WrapMethod.WRAP_API`` wraps an ``nn.Linear`` through
            ``enable_wrap``/``wrap``; ``WrapMethod.FSDP_CTOR`` calls the
            ``FSDP`` constructor directly with a size-based auto-wrap policy.
            Any other value fails the test.
    """
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
    elif wrap_method == WrapMethod.FSDP_CTOR:
        layer = FSDP(
            nn.Linear(5, 5),
            process_group=self.process_group,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy, min_num_params=1
            ),
        )
    else:
        # Explicit failure instead of a bare ``assert``: it is not stripped
        # under ``-O`` and it names the offending value.
        self.fail(f"Unsupported wrap method: {wrap_method}")

    # assertIsInstance reports the actual type on failure, unlike
    # assertTrue(isinstance(...)) which only says "False is not true".
    self.assertIsInstance(layer, FSDP)
    self.assertEqual(layer.rank, self.process_group.rank())
    self.assertEqual(layer.world_size, self.process_group.size())
def test_wrap_disabled_outside_context(self):
    """Check that ``wrap`` only produces FSDP modules inside ``enable_wrap``.

    ``MyModel`` calls ``wrap`` on its linear layer during ``__init__``, which
    here runs before any ``enable_wrap`` context is entered. The model itself
    is then wrapped inside the context. The outer module is expected to be
    FSDP while the inner layer stays a plain ``nn.Linear``.
    """
    pg = self.process_group

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            # Called outside ``enable_wrap`` when the model is built below.
            self.lin = wrap(nn.Linear(5, 5), process_group=pg)

    model = MyModel()
    with enable_wrap(wrapper_cls=FSDP, process_group=pg):
        model = wrap(model)

    # The type-aware assertions report the actual class on failure instead of
    # the bare "False is not true" from assertTrue(isinstance(...)).
    self.assertIsInstance(model, FSDP)
    self.assertNotIsInstance(model.lin, FSDP)
    self.assertIsInstance(model.lin, nn.Linear)
def _create_module(wrap_fsdp=True):
    """Build a ``SkipModel`` and wrap it while its ``linear_skip`` child is detached.

    Args:
        wrap_fsdp: When true, construction and wrapping run under
            ``enable_wrap(wrapper_cls=FSDP)``; otherwise under ``suppress()``.

    Returns:
        A ``(wrapped, names)`` tuple: the result of ``wrap(module)`` and the
        list of parameter names containing ``"linear_skip"``, collected
        before the child was detached.

    NOTE(review): ``double_nest`` is not defined in this block; it is read
    from an enclosing scope that is not visible here.
    """
    LINEAR_SKIP = "linear_skip"

    if wrap_fsdp:
        wrap_scope = enable_wrap(wrapper_cls=FSDP)
    else:
        # presumably contextlib.suppress used as a do-nothing context -- confirm
        wrap_scope = suppress()

    with wrap_scope:
        model = SkipModel(double_nest=double_nest)

        # Full names of the linear_skip parameter tensors in SkipModel, as
        # they would be stored in a checkpoint.
        skipped_names = [
            name
            for name in dict(model.named_parameters())
            if LINEAR_SKIP in name
        ]

        # Detach the child so that wrapping does not see it.
        skipped_child = getattr(model, LINEAR_SKIP)
        delattr(model, LINEAR_SKIP)

        wrapped = wrap(model)

        # Put the untouched child back on the original module.
        setattr(model, LINEAR_SKIP, skipped_child)
        return wrapped, skipped_names
def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
    """Check that ``always_wrap_policy`` wraps every module except the ignored ones.

    Two submodules of the nested sequential model (``sequential[1]`` and
    ``sequential[2][0]``) are passed as ``ignored_modules``. After wrapping,
    those two must still be ``nn.Linear`` while every other module checked
    here must be an ``FSDP`` instance.

    Args:
        wrap_method: ``WrapMethod.FSDP_CTOR`` to call the ``FSDP`` constructor
            directly, or ``WrapMethod.WRAP_API`` to go through
            ``enable_wrap``/``wrap``. Any other value fails the test.
    """
    sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
    ignored_modules = [sequential[1], sequential[2][0]]
    fsdp_kwargs = {
        "process_group": self.process_group,
        "auto_wrap_policy": always_wrap_policy,
        "ignored_modules": ignored_modules,
    }

    if wrap_method == WrapMethod.FSDP_CTOR:
        model = FSDP(sequential, **fsdp_kwargs)
    elif wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
            model = wrap(sequential)
    else:
        # ``assert 0`` would be stripped under ``-O`` and surface later as an
        # unbound ``model``; fail explicitly instead.
        self.fail(f"Unsupported wrap method: {wrap_method}")

    # All non-ignored modules should be wrapped with FSDP.
    self.assertIsInstance(model, FSDP)
    self.assertIsInstance(model.module[0], FSDP)
    self.assertIsInstance(model.module[1], nn.Linear)
    self.assertIsInstance(model.module[2], FSDP)
    self.assertIsInstance(model.module[2].module[0], nn.Linear)
    self.assertIsInstance(model.module[2].module[1], FSDP)
def test_distributed_checkpoint(self, state_dict_type) -> None:
    """Round-trip a wrapped model's state through a filesystem checkpoint.

    Two ``SkipModel(double_nest=True)`` instances are built under
    ``enable_wrap(wrapper_cls=FSDP)`` with different manual seeds (100 and
    200) and asserted to have unequal parameters. The state dict of ``model``
    is saved through a ``FileSystemWriter``. The state dict of ``new_model``
    is then filled through a ``FileSystemReader`` and loaded back into
    ``new_model``, after which the parameters are asserted equal.

    Args:
        state_dict_type: Passed to ``FSDP.state_dict_type`` for both models
            around every ``state_dict()`` call.

    NOTE(review): the ``with`` nesting below was reconstructed from a
    whitespace-collapsed source -- confirm against the original layout.
    """
    with enable_wrap(wrapper_cls=FSDP):
        # Different seeds so the two models start with different weights.
        torch.manual_seed(100)
        model = wrap(SkipModel(double_nest=True))
        torch.manual_seed(200)
        new_model = wrap(SkipModel(double_nest=True))

    # Sanity check: the models differ before the checkpoint round trip.
    with FullyShardedDataParallel.summon_full_params(
        model), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertNotEqual(params, new_params)

    with tempfile.TemporaryDirectory() as path:
        # presumably makes every rank use the same directory as the
        # broadcast source -- confirm
        paths = [path]
        dist.broadcast_object_list(paths)
        path = paths[0]
        writer = FileSystemWriter(path)
        reader = FileSystemReader(path)
        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type):
            state_dict = model.state_dict()

        save_state_dict(state_dict, writer)

        with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type(
            new_model, state_dict_type):
            # Fill new_model's own state dict from the checkpoint, then
            # load it back into new_model.
            state_dict = new_model.state_dict()
            load_state_dict(state_dict, reader)
            new_model.load_state_dict(state_dict)

    # After loading, both models are expected to hold the same parameters.
    with FullyShardedDataParallel.summon_full_params(
        model), FullyShardedDataParallel.summon_full_params(new_model):
        params = list(model.parameters())
        new_params = list(new_model.parameters())
        self.assertEqual(params, new_params)