def test_scaling_unscaling_sparse(self):
    pg = DummyProcessGroup(0, 1)
    scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
    inv_scale = torch.full((1,), 0.5, dtype=torch.float, device="cpu")
    found_inf = torch.full((1,), 0, dtype=torch.float, device="cpu")

    i = torch.tensor([[0, 1, 1], [2, 0, 2]], device="cpu", dtype=torch.int64)
    v = torch.tensor([16.0, 32.0, 64.0], dtype=torch.float, device="cpu")
    s = torch.sparse_coo_tensor(
        i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float
    )

    # unscale sparse tensors
    s1 = s.clone()
    s1.grad = s.clone()
    opt = torch.optim.SGD([s1], lr=1.0)
    found_inf.zero_()
    found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device]
    self.assertEqual(found_inf, 0.0)
    self.assertEqual(s1.grad.to_dense(), (s / 2).to_dense())

    # unscale sparse tensor: inf
    v = torch.tensor([16.0, 32.0, float("inf")], dtype=torch.float, device="cpu")
    s1.grad = torch.sparse_coo_tensor(
        i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float
    )
    found_inf.zero_()
    found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device]
    self.assertEqual(found_inf, 1.0)

    # unscale sparse tensor: overflow (marked as inf)
    i = torch.tensor([[1, 1, 1], [0, 0, 2]], device="cpu", dtype=torch.int64)
    # coalescing sparse tensor here will cause the value to be Inf
    v = torch.tensor([2**15, 2**15, 1.0], dtype=torch.float16, device="cpu")
    s1 = torch.sparse_coo_tensor(
        i, v, torch.Size([2, 3]), device="cpu", dtype=torch.float16
    )
    s1.grad = s1.clone()
    found_inf.zero_()
    found_inf = scaler._unscale_grads_(opt, inv_scale, found_inf)[s1.device]
    self.assertEqual(found_inf, 1.0)
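
# Arithmetic behind the float16 overflow case above (illustrative note, not part
# of the original test): float16 tops out at 65504, so coalescing the two
# duplicate entries at index (1, 0) sums 2**15 + 2**15 = 65536, which rounds to
# inf; that inf is what `found_inf` then reports.
def _fp16_overflow_sketch():
    assert torch.finfo(torch.float16).max == 65504.0
    # 65536.0 is not representable in float16 and rounds up to inf.
    assert torch.tensor(2.0 ** 16, dtype=torch.float16).item() == float("inf")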
def test_grad_scaling(self):
    pg = DummyProcessGroup(0, 1)
    scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
    t0 = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
    t1 = torch.full((1,), 8.0, dtype=torch.float32, device="cpu")
    outputs = [t1.clone(), (t0.clone(), t1.clone()), [t0.clone(), t1.clone()]]
    outputs = scaler.scale(outputs)
    self.assertTrue(
        outputs[0] == 16.0 and outputs[1][0] == 8.0 and outputs[1][1] == 16.0
    )
    self.assertTrue(outputs[2][0] == 8.0 and outputs[2][1] == 16.0)
    self.assertTrue(scaler._scale.device == t1.device)
def test_inf_gradients_skip_optim_step(self):
    pg = DummyProcessGroup(0, 1)
    scaler = ShardedGradScaler(init_scale=2.0, process_group=pg, enabled=True)
    loss = torch.full((1,), 4.0, dtype=torch.float32, device="cpu")
    t0 = torch.tensor([float("inf")], dtype=torch.float32, device="cpu")
    t0.grad = t0.clone()
    opt = torch.optim.SGD([t0], lr=1.0)
    scaler.scale(loss)
    ret_val = scaler.step(opt)
    self.assertTrue(ret_val is None)
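
# A minimal usage sketch (an assumption for illustration, not part of the
# original tests): ShardedGradScaler is meant to follow the familiar
# torch.cuda.amp.GradScaler flow of scale() / step() / update(), which is the
# behavior the tests above exercise directly. `model`, `optimizer`, and `loader`
# are hypothetical placeholders.
def _sharded_grad_scaler_loop_sketch(model, optimizer, loader):
    scaler = ShardedGradScaler(init_scale=2.0)
    for inputs, targets in loader:
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        scaler.scale(loss).backward()  # backward through the scaled loss
        scaler.step(optimizer)         # returns None and skips the step on inf/nan grads
        scaler.update()                # adjusts the scale for the next iteration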
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                fsdp_auto_wrap_policy=functools.partial(
                    default_auto_wrap_policy, min_num_params=1
                ),
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, nn.Linear))

    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        ``nn.Linear(5, 5)`` does not exceed the bucket size on its own, but combined they do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size
        exceeds min_num_params. The default_auto_wrap_policy excludes wrapping for
        {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if the param size
        exceeds min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and their children are not wrapped.
        The default_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("fsdp_init_mode", [FSDPInitMode.CUDA_BEFORE, FSDPInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
        # CPU offload and CUDA after don't work together as expected.
        if cpu_offload.offload_params and fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        # NOTE: We move the model to CUDA after init with FSDP to simulate real use
        # cases where the full model cannot be loaded onto the GPU, but its shards can.
        cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init)
            )
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                fsdp_auto_wrap_policy=my_auto_wrap_policy,
            )
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                auto_wrap_policy=functools.partial(
                    size_based_auto_wrap_policy, min_num_params=1
                ),
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_wrap_disabled_outside_context(self):
        pg = self.process_group

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin = wrap(nn.Linear(5, 5), process_group=pg)

        model = MyModel()
        with enable_wrap(wrapper_cls=FSDP, process_group=pg):
            model = wrap(model)

        self.assertTrue(isinstance(model, FSDP))
        self.assertFalse(isinstance(model.lin, FSDP))
        self.assertTrue(isinstance(model.lin, nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertTrue(layer.process_group is new_process_group)
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def test_always_wrap(self):
        """
        Test to ensure that if `always_wrap_policy` is passed into FSDP, all submodules are wrapped.
        """
        seq = TestFSDPWrap.NestedSequentialModel.get_model(cuda=True)
        model = FSDP(
            seq, process_group=self.process_group, auto_wrap_policy=always_wrap_policy
        )
        TestFSDPWrap.NestedSequentialModel.verify_model_all_wrapped(self, model)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_transformer_auto_wrap_policy(self):
        """Tests the ``transformer_auto_wrap_policy``."""
        auto_wrap_policy = functools.partial(
            transformer_auto_wrap_policy,
            transformer_layer_cls={
                TransformerEncoderLayer,
                TransformerDecoderLayer,
            },
        )
        fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
        fsdp_model = TransformerWithSharedParams.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            fsdp_kwargs,
        )
        modules = list(fsdp_model.modules())
        encoder_layers = set(fsdp_model.module.transformer.encoder.layers)
        decoder_layers = set(fsdp_model.module.transformer.decoder.layers)
        for module in modules:
            if module is fsdp_model or module in encoder_layers or module in decoder_layers:
                self.assertTrue(isinstance(module, FSDP))
            else:
                self.assertFalse(isinstance(module, FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_api(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        ``nn.Linear(5, 5)`` does not exceed the bucket size on its own, but combined they do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size
        exceeds min_num_params. The size_based_auto_wrap_policy excludes wrapping for
        {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if the param size
        exceeds min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and their children are not wrapped.
        The size_based_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=size_based_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            auto_wrap_policy=my_auto_wrap_policy,
        )
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("cuda_init_mode", [CUDAInitMode.CUDA_BEFORE, CUDAInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)],
    )
    @parametrize("use_device_id", [True, False])
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
        # CPU offload and CUDA after don't work together as expected.
        if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (
            torch.device("cuda", torch.cuda.current_device())
            if use_device_id
            else None
        )

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        file_name = tempfile.NamedTemporaryFile(delete=False).name
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=f"{FILE_SCHEMA}_{file_name}",
            rank=0,
            world_size=1,
        )

        # NOTE: We move the model to CUDA after init with FSDP to simulate real use
        # cases where the full model cannot be loaded onto the GPU, but its shards can.
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init)
            )
            my_auto_wrap_policy = functools.partial(
                size_based_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                auto_wrap_policy=my_auto_wrap_policy,
                device_id=device_id,
            )
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            try:
                os.remove(file_name)
            except FileNotFoundError:
                pass

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": always_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # All non-ignored modules should be wrapped with FSDP
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], FSDP))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], FSDP))

    @unittest.skipIf(torch.cuda.device_count() < 2, "Requires at least 2 GPUs")
    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        ignored_modules = [sequential[1], sequential[2][0]]
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=40,
        )
        fsdp_kwargs = {
            "process_group": self.process_group,
            "auto_wrap_policy": my_auto_wrap_policy,
            "ignored_modules": ignored_modules,
        }
        if wrap_method == WrapMethod.FSDP_CTOR:
            model = FSDP(sequential, **fsdp_kwargs)
        elif wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
                model = wrap(sequential)
        else:
            assert 0, f"Unsupported wrap method: {wrap_method}"
        # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
        # policy does not exceed the parameter threshold before the inner
        # sequential (`sequential[2]`) anymore; hence, it flattens
        # `sequential[0]` and `sequential[2][0]` into `model` and leaves
        # `sequential[1]` and `sequential[2][1]` as-is since they are ignored
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], nn.Sequential))
        self.assertTrue(isinstance(model.module[2][0], nn.Linear))
        self.assertTrue(isinstance(model.module[2][1], nn.Linear))
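
# A condensed sketch (an assumption for illustration, not part of the original
# tests) of the size-based wrapping behavior exercised above: with
# min_num_params=40, a child with fewer than 40 parameters should stay a plain
# module, a larger child should get its own FSDP wrapper, and the root is
# wrapped by the constructor either way. `process_group` would be, e.g., the
# DummyProcessGroup(rank=0, size=1) used throughout these tests.
def _size_based_wrap_sketch(process_group):
    policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
    model = FSDP(
        nn.Sequential(nn.Linear(5, 5), nn.Linear(10, 10)),
        process_group=process_group,
        auto_wrap_policy=policy,
    )
    # nn.Linear(5, 5) has 30 parameters (below the threshold); nn.Linear(10, 10)
    # has 110 and is expected to be wrapped in its own FSDP instance.
    return model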
class TestAutoWrap(unittest.TestCase):
    def setUp(self) -> None:
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    def test_wrap(self):
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, nn.Linear))

    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    def test_auto_wrap(self):
        """
        Test to ensure that with auto wrap, we wrap child modules correctly based on min_num_params.
        ``nn.Linear(5, 5)`` does not exceed the bucket size on its own, but combined they do.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.Linear))
        self.assertTrue(isinstance(model.module[2], FSDP))
        self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size
        exceeds min_num_params. The default_auto_wrap_policy excludes wrapping for
        {nn.ModuleList, nn.ModuleDict}.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model, nn.ModuleList))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but their children are if the param size
        exceeds min_num_params.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.ModuleList([nn.Linear(10, 10)])
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model, nn.ModuleList))
        self.assertTrue(isinstance(model[0], FSDP))

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test to ensure force-leaf modules are not wrapped, and their children are not wrapped.
        The default_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            sequential = nn.Sequential(
                nn.Linear(10, 10), nn.MultiheadAttention(100, 1)
            )
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        with enable_wrap(
            auto_wrap_policy=my_auto_wrap_policy,
            wrapper_cls=FSDP,
            process_group=self.process_group,
        ):
            sequential = nn.Sequential(
                nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
            )
            model = auto_wrap(sequential)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    def test_auto_wrap_smoke_test(self):
        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port in case the next test runs quickly; reusing the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        try:
            with enable_wrap(wrapper_cls=FSDP):
                sequential = nn.Sequential(
                    nn.Linear(5, 5),
                    nn.Linear(5, 5),
                    nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
                )
                my_auto_wrap_policy = functools.partial(
                    default_auto_wrap_policy, min_num_params=40
                )
                model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
            model.to(device)
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]