class TestClipGradNorm(FSDPTest):
    def _run_fsdp_one_iteration(self, norm_type, nested_fsdp, cpu_offload):
        """Test FSDP with clip grad norm."""
        fsdp_model = DeterministicModel(nested_fsdp, cpu_offload=cpu_offload)
        local_model = DeterministicModel(False)
        input = torch.rand(14, 2, device=self.rank)
        fsdp_model = FSDP(fsdp_model, cpu_offload=cpu_offload)
        self.assertTrue(len(input) >= self.world_size)
        out = local_model(input[:self.world_size])
        out.sum().backward()
        in_data = torch.tensor(input[self.rank], device=self.rank)
        out_fsdp = fsdp_model(in_data)
        out_fsdp.sum().backward()
        total_norms_fsdp = _collect_total_grad_norm_fsdp(
            fsdp_model, norm_type, self.rank)
        total_norms_local = _collect_total_grad_norm_local(
            local_model, norm_type)
        total_norms_local /= self.world_size
        norm_cap = total_norms_fsdp / 2.0
        self.assertEqual(total_norms_local, total_norms_fsdp)
        fsdp_model.clip_grad_norm_(norm_cap, norm_type=norm_type)
        nn_utils.clip_grad_norm_(local_model.parameters(),
                                 norm_cap,
                                 norm_type=norm_type)
        total_norms_after_clip_fsdp = _collect_total_grad_norm_fsdp(
            fsdp_model, norm_type, self.rank)
        total_norms_after_clip_local = _collect_total_grad_norm_local(
            local_model, norm_type)
        self.assertTrue(total_norms_after_clip_fsdp <= norm_cap)
        self.assertEqual(total_norms_after_clip_local,
                         total_norms_after_clip_fsdp)

    @skip_if_lt_x_gpu(2)
    @parametrize("norm_type", [2.0, inf])
    @parametrize("nested_fsdp", [True, False])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    def test_fsdp_clip_grad_norm(self, norm_type, nested_fsdp, cpu_offload):
        """Test FSDP with clip grad norm."""
        self._run_fsdp_one_iteration(norm_type, nested_fsdp, cpu_offload)
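# For reference, a minimal sketch (not part of the test suite; the model and
# data below are made up for illustration) of the clipping contract the test
# above relies on: `torch.nn.utils.clip_grad_norm_` returns the total gradient
# norm computed *before* clipping and rescales gradients in place so that their
# total norm does not exceed `max_norm`.
#
#     import torch
#     import torch.nn as nn
#
#     model = nn.Linear(4, 4)
#     model(torch.randn(2, 4)).sum().backward()
#     total_norm = nn.utils.clip_grad_norm_(
#         model.parameters(), max_norm=0.1, norm_type=2.0)
#     # After the call, the 2-norm over all gradients is at most 0.1 (up to
#     # floating point error), which is the property the FSDP and local models
#     # are compared against above.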
@classmethod
def register_strategies(cls, strategy_registry: Dict) -> None:
    if _TORCH_GREATER_EQUAL_1_12:
        strategy_registry.register(
            "fsdp_native",
            cls,
            description="Fully Sharded Data Parallel training from torch.distributed.",
        )
        cls._registered_strategies.append("fsdp_native")
        strategy_registry.register(
            "fsdp_native_full_shard_offload",
            cls,
            description="Native FSDP with Full Sharding and CPU Offloading",
            cpu_offload=CPUOffload(offload_params=True),
        )
        cls._registered_strategies.append("fsdp_native_full_shard_offload")
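# A minimal usage sketch for the registered strategy names (assuming a
# PyTorch Lightning release in which these strings are registered, as above;
# the accelerator/devices values are illustrative):
#
#     from pytorch_lightning import Trainer
#
#     trainer = Trainer(accelerator="gpu", devices=2, strategy="fsdp_native")
#     # or, to additionally offload sharded parameters to CPU:
#     # trainer = Trainer(accelerator="gpu", devices=2,
#     #                   strategy="fsdp_native_full_shard_offload")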
class TestFSDPWrap(FSDPTest):
    """
    Tests main API for wrapping FSDP, which is to pass auto_wrap_policy into
    FSDP constructor.
    """

    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        @staticmethod
        def get_model(cuda=True):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model_all_wrapped(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], FSDP))
            cls.assertTrue(isinstance(model.module[1], FSDP))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))

        @staticmethod
        def verify_model(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[1], nn.Linear))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            # The following modules were not wrapped by the policy.
            cls.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def _get_linear(self, fin, fout):
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(self,
                                  cuda_init_mode=CUDAInitMode.CUDA_BEFORE,
                                  nested=False) -> FSDP:
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE
                # If nested=True, the FSDP module will be nested one layer deep
                # and we should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1),
                                         move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin2 = FSDP(
                    _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(
                    _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        model = MyModel(nested=nested)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("cuda_init_mode",
                 [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, cuda_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, cuda_init_mode=cuda_init_mode)
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()
        with self.assertRaisesRegex(ValueError,
                                    "to NOT be FullyShardedDataParallel"):
            mod = FSDP(wrapped_fsdp,
                       auto_wrap_policy=size_based_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize("use_or_policy", [True, False])
    def test_wrap_batchnorm_individually(self, use_or_policy):
        def never_wrap_policy(*args, **kwargs):
            return False

        policy = (functools.partial(
            _or_policy,
            policies=[never_wrap_policy, _wrap_batchnorm_individually])
                  if use_or_policy else _wrap_batchnorm_individually)
        model = BatchNormNet()
        fsdp = FSDP(model, auto_wrap_policy=policy)
        # Batchnorms should be wrapped.
        for layer in [fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn]:
            self.assertTrue(isinstance(layer, FSDP))
        self.assertFalse(isinstance(fsdp.lin, FSDP))

    @skip_if_lt_x_gpu(2)
    def test_bn_always_wrapped_individually(self):
        """
        Ensures that by using _or_policy with _wrap_batchnorm_individually,
        even if the other policy results in a module containing a BN unit being
        wrapped, the contained BN unit will still be individually wrapped.
        """
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.bn_container = BatchNormNet()

        def wrap_bn_container(module, recurse, *args, **kwargs):
            if recurse:
                return True
            return isinstance(module, BatchNormNet)

        my_policy = functools.partial(
            _or_policy,
            policies=[wrap_bn_container, _wrap_batchnorm_individually])
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=my_policy)

        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
        # and not FSDP(FSDP(BatchNormNet(BN))) (in the latter the inner
        # BN is not individually wrapped).
        for bn in [
                fsdp.bn_container.bn1, fsdp.bn_container.bn2,
                fsdp.bn_container.bn3, fsdp.bn_container.sync_bn
        ]:
            self.assertTrue(isinstance(bn, FSDP))

        # If we just wrapped the BN container, the individual batchnorms are
        # not wrapped.
        mod = MyModule()
        fsdp = FSDP(mod, auto_wrap_policy=wrap_bn_container)
        self.assertTrue(isinstance(mod.bn_container, FSDP))
        for bn in [
                fsdp.bn_container.bn1, fsdp.bn_container.bn2,
                fsdp.bn_container.bn3, fsdp.bn_container.sync_bn
        ]:
            self.assertFalse(isinstance(bn, FSDP))

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE])
    @parametrize("forward_prefetch", [True, False])
    @parametrize("cuda_init_mode",
                 [CUDAInitMode.CUDA_AFTER, CUDAInitMode.CUDA_BEFORE])
    def test_main_wrap_api(self, cpu_offload, backward_prefetch,
                           forward_prefetch, cuda_init_mode):
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # They don't work together; this is expected.
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                              move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1, wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4, wrapped_model
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run the model a few times as a sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(),
                                lr=1e-2,
                                momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()

        # Since we ran with backward prefetch, verify backward-prefetch-related
        # data.
        for i, module in enumerate(modules_in_fsdp_graph_order):
            self.assertEqual(i, module._my_fsdp_idx_in_graph)
            self.assertTrue(
                module._fsdp_graph_order == modules_in_fsdp_graph_order)
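# A minimal sketch of a custom `auto_wrap_policy`, following the callable
# pattern used in the tests above (extra size arguments are absorbed via
# *args/**kwargs). This is illustrative and not an FSDP-provided policy: it
# wraps every nn.Linear in its own FSDP unit while recursing through
# containers.
#
#     def wrap_linear_individually(module, recurse, *args, **kwargs):
#         if recurse:
#             return True  # keep traversing into child modules
#         return isinstance(module, nn.Linear)
#
#     # e.g. FSDP(my_model, auto_wrap_policy=wrap_linear_individually)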
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group.
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(nn.Linear(5, 5),
                         process_group=self.process_group,
                         auto_wrap_policy=functools.partial(
                             size_based_auto_wrap_policy, min_num_params=1))
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP,
                         process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is passed into FSDP, all
        submodules are wrapped.
        """
        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
        model = FSDP(seq,
                     process_group=self.process_group,
                     auto_wrap_policy=always_wrap_policy)
        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(
            self, model)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                TransformerEncoderLayer, TransformerDecoderLayer
            },
        )
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        modules = list(fsdp_model.modules())
        encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
        decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
        for module in modules:
            if (module is fsdp_model or module in encoder_layers
                    or module in decoder_layers):
                self.assertTrue(isinstance(module, FSDP))
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly
        based on the min_num_params. ``nn.Linear(5, 5)`` does not exceed the
        bucket size, but combined they do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether
        the total param size is greater than the min_num_params. The
        size_based_auto_wrap_policy excludes wrapping for
        {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but children are if
        param size is greater than min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and children are not
        wrapped. The size_based_auto_wrap_policy forces leaf modules of type
        {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10),
                                   nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(size_based_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped.
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=size_based_auto_wrap_policy.FORCE_LEAF_MODULES.
            union({nn.Linear}),
        )
        sequential = nn.Sequential(nn.Linear(10, 10),
                                   nn.ModuleList([nn.Linear(10, 10)]))
        model = FSDP(sequential,
                     process_group=self.process_group,
                     auto_wrap_policy=my_auto_wrap_policy)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("cuda_init_mode",
                 [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)])
    @parametrize("use_device_id", [True, False])
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload,
                                  use_device_id):
        # CPU offload and CUDA after don't work together as expected.
        if (cpu_offload.offload_params
                and cuda_init_mode == CUDAInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (torch.device("cuda", torch.cuda.current_device())
                     if use_device_id else None)

        # Use a random port in case the next test runs quickly; reusing the
        # same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after init with FSDP to simulate real
        # use cases where the full model cannot be loaded onto GPU, but its
        # shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40)
            model = FSDP(sequential,
                         cpu_offload=cpu_offload,
                         auto_wrap_policy=my_auto_wrap_policy,
                         device_id=device_id)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # All non-ignored modules should be wrapped with FSDP.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], FSDP))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
        # policy does not exceed the parameter threshold before the inner
        # sequential (`sequential[2]`) anymore; hence, it flattens
        # `sequential[0]` and `sequential[2][0]` into `model` and leaves
        # `sequential[1]` and `sequential[2][1]` as-is since they are ignored.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], nn.Sequential))
        self.assertTrue(isinstance(model.module[2][0], nn.Linear))
        self.assertTrue(isinstance(model.module[2][1], nn.Linear))
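# A minimal sketch of `ignored_modules` outside the test harness (the model and
# module choice are hypothetical, and a process group is assumed to be
# initialized, as in the tests above): parameters of ignored modules stay as
# regular nn.Parameters and are neither flattened nor sharded by FSDP.
#
#     frozen = nn.Linear(5, 5)
#     net = nn.Sequential(nn.Linear(5, 5), frozen)
#     fsdp_net = FSDP(net, ignored_modules=[frozen])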
class TestFSDPWrap(FSDPTest):
    """
    Tests main API for wrapping FSDP, which is to pass auto_wrap_policy into
    FSDP constructor.
    """

    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        @staticmethod
        def get_model(cuda=True):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model_all_wrapped(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], FSDP))
            cls.assertTrue(isinstance(model.module[1], FSDP))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[0], FSDP))
            cls.assertTrue(isinstance(model.module[2].module[1], FSDP))

        @staticmethod
        def verify_model(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[1], nn.Linear))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            # The following modules were not wrapped by the policy.
            cls.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def _get_linear(self, fin, fout):
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(self,
                                  fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                                  nested=False) -> FSDP:
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
                # If nested=True, the FSDP module will be nested one layer deep
                # and we should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1),
                                         move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin2 = FSDP(
                    _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(
                    _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        model = MyModel(nested=nested)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("fsdp_init_mode",
                 [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, fsdp_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, fsdp_init_mode=fsdp_init_mode)
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()
        with self.assertRaisesRegex(ValueError,
                                    "to NOT be FullyShardedDataParallel"):
            mod = FSDP(wrapped_fsdp,
                       auto_wrap_policy=default_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)])
    @parametrize(
        "backward_prefetch",
        [BackwardPrefetch.BACKWARD_POST, BackwardPrefetch.BACKWARD_PRE])
    @parametrize("fsdp_init_mode",
                 [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE])
    def test_main_wrap_api(self, cpu_offload, backward_prefetch,
                           fsdp_init_mode):
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # They don't work together; this is expected.
            return

        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                              move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False),
                                        move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                default_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
        )
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1, wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4, wrapped_model
        ]

        for module in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)

        # Run the model a few times as a sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(),
                                lr=1e-2,
                                momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()

        # Since we ran with backward prefetch, verify backward-prefetch-related
        # data.
        for i, module in enumerate(modules_in_fsdp_graph_order):
            self.assertEqual(i, module._my_fsdp_idx_in_graph)
            self.assertTrue(
                module._fsdp_graph_order == modules_in_fsdp_graph_order)
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group.
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(nn.Linear(5, 5),
                         process_group=self.process_group,
                         fsdp_auto_wrap_policy=functools.partial(
                             default_auto_wrap_policy, min_num_params=1))
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP,
                         process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly
        based on the min_num_params. ``nn.Linear(5, 5)`` does not exceed the
        bucket size, but combined they do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     fsdp_auto_wrap_policy=my_auto_wrap_policy)

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether
        the total param size is greater than the min_num_params. The
        default_auto_wrap_policy excludes wrapping for
        {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     fsdp_auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but children are if
        param size is greater than min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     fsdp_auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and children are not
        wrapped. The default_auto_wrap_policy forces leaf modules of type
        {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10),
                                   nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        model = FSDP(sequential,
                     process_group=self.process_group,
                     fsdp_auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped.
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.
            union({nn.Linear}),
        )
        sequential = nn.Sequential(nn.Linear(10, 10),
                                   nn.ModuleList([nn.Linear(10, 10)]))
        model = FSDP(sequential,
                     process_group=self.process_group,
                     fsdp_auto_wrap_policy=my_auto_wrap_policy)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("fsdp_init_mode",
                 [FSDPInitMode.CUDA_BEFORE, FSDPInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)])
    def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
        # CPU offload and CUDA after don't work together as expected.
        if (cpu_offload.offload_params
                and fsdp_init_mode == FSDPInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the
        # same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after init with FSDP to simulate real
        # use cases where the full model cannot be loaded onto GPU, but its
        # shards can.
        cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                    min_num_params=40)
            model = FSDP(sequential,
                         cpu_offload=cpu_offload,
                         fsdp_auto_wrap_policy=my_auto_wrap_policy)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()

        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass
class TestFSDPCheckpoint(FSDPTest):
    class SequentialModule(nn.Module):
        def __init__(
            self,
            checkpoint_layer=False,
            offload_activations=False,
            wrap_fsdp=False,
            *fsdp_args,
            **fsdp_kwargs,
        ):
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            super().__init__()
            l1 = nn.Linear(3, 3).cuda()
            l2 = nn.Linear(3, 3).cuda()
            l3 = nn.Linear(3, 3).cuda()

            if checkpoint_layer:
                ckpt_wrapper = partial(
                    checkpoint_wrapper, offload_to_cpu=offload_activations
                )
                l1 = ckpt_wrapper(l1)
                l2 = ckpt_wrapper(l2)
                l3 = ckpt_wrapper(l3)

            fsdp_wrapper = partial(
                _maybe_wrap_fsdp, wrap_fsdp=wrap_fsdp, *fsdp_args, **fsdp_kwargs
            )
            self.ffn = nn.Sequential(
                fsdp_wrapper(l1),
                fsdp_wrapper(l2),
                fsdp_wrapper(l3),
            )

        def forward(self, x):
            return self.ffn(x)

    def _verify_parity(self, losses, outputs, models):
        assert losses
        assert outputs
        assert models

        for (l, o) in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

        # Verify grads.
        ref_model = models[0]
        ref_grads = [p.grad for p in ref_model.parameters()]
        for m in models[1:]:
            grads = [p.grad for p in m.parameters()]
            for ref_g, g in zip(ref_grads, grads):
                self.assertEqual(ref_g, g)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True, cpu_offload=cpu_offload
            ),
            offload_to_cpu=offload_activations,
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            cpu_offload=cpu_offload,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True, cpu_offload=cpu_offload
        )

        # Note that reentrant-based checkpointing requires inputs to have the
        # grad flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(),
                          requires_grad=True)

        global _save_on_cpu_called
        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            for i in range(2):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != baseline and i == 0 and offload_activations
                    if check_offload:
                        self.assertFalse(_save_on_cpu_called)
                    out = m(inp)
                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                        _save_on_cpu_called = False
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)

                self._verify_parity(losses, outputs, models)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(
                torch.cuda.current_device())
            # Runs FSDP with no checkpointing.
            fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
            # Runs checkpoint-wrapped FSDP.
            checkpointed_fsdp = checkpoint_wrapper(
                FSDP(deepcopy(seq), cpu_offload=cpu_offload),
                offload_to_cpu=offload_activations,
            )
            # Runs FSDP-wrapped checkpointed module.
            fsdp_wrapped_checkpoint = FSDP(
                checkpoint_wrapper(deepcopy(seq),
                                   offload_to_cpu=offload_activations),
                cpu_offload=cpu_offload,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
            # Note that reentrant-based checkpointing requires inputs to have
            # the grad flag set.
            inp = torch.randn(10, 3, device=torch.cuda.current_device(),
                              requires_grad=True)

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure _save_on_cpu has not yet been called.
            self.assertFalse(_save_on_cpu_called)
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != fsdp_only_seq and i == 0 and offload_activations
                    if m == fsdp_call_checkpoint:
                        # _save_on_cpu should not be called yet.
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.suppress()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        # _save_on_cpu should not be called yet.
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)
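# A minimal sketch of the composition patterns exercised above, reusing the
# helpers already imported by the test file (`FSDP`, `checkpoint_wrapper`,
# `deepcopy`); the module is illustrative and a process group is assumed to be
# initialized, as in the tests. Activation checkpointing can sit either inside
# or outside the FSDP wrapper.
#
#     block = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3)).cuda()
#     fsdp_of_ckpt = FSDP(checkpoint_wrapper(deepcopy(block)))  # FSDP(checkpoint(module))
#     ckpt_of_fsdp = checkpoint_wrapper(FSDP(deepcopy(block)))  # checkpoint(FSDP(module))
#     # A third option, also tested above, is calling
#     # torch.utils.checkpoint.checkpoint manually on an FSDP-wrapped module
#     # inside the training loop.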