def test_auto_wrap_preset_force_leaf_custom(self, wrap_method):
    """
    Test to ensure force-leaf modules are not wrapped.
    """
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy,
        min_num_params=40,
        force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
            {nn.Linear}
        ),
    )
    sequential = nn.Sequential(
        nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(
            auto_wrap_policy=my_auto_wrap_policy,
            wrapper_cls=FSDP,
            process_group=self.process_group,
        ):
            model = auto_wrap(sequential)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
    # Model was wrapped in FSDP as no inner modules were wrapped.
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model.module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[1], nn.ModuleList))

def test_auto_wrap_preset_exclude_wrap(self, wrap_method):
    """
    Test to ensure excluded modules are not wrapped, regardless of whether the
    total param size is greater than min_num_params. The
    default_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.
    """
    sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
    # Note that the outermost module is an FSDP instance for the FSDP_CTOR
    # approach, because calling the FSDP ctor necessarily returns an FSDP
    # instance. If we don't want to shard the outermost module based on the
    # policy, we can apply the policy manually to the outermost instance and
    # skip the sharding (see the sketch after this test).
    if wrap_method == WrapMethod.WRAP_API:
        self.assertTrue(isinstance(model, nn.ModuleList))
    else:
        self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model[0], nn.Linear))
    self.assertTrue(isinstance(model[1], nn.Linear))

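# A minimal sketch (not exercised by these tests) of the "apply the policy
# manually to the outermost instance" idea mentioned above. It assumes the
# policy callable follows default_auto_wrap_policy's
# (module, recurse, unwrapped_params) -> bool calling convention:
#
#     root = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
#     total_params = sum(p.numel() for p in root.parameters())
#     if my_auto_wrap_policy(root, recurse=False, unwrapped_params=total_params):
#         # Policy says to shard the root, so call the FSDP ctor as usual.
#         root = FSDP(root, fsdp_auto_wrap_policy=my_auto_wrap_policy)
#     else:
#         # Otherwise leave the root unwrapped and only auto-wrap its children,
#         # e.g. via enable_wrap()/auto_wrap() as in the WRAP_API branch above.
#         with enable_wrap(wrapper_cls=FSDP):
#             root = auto_wrap(root, auto_wrap_policy=my_auto_wrap_policy)
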
def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
    # CPU offload and CUDA_AFTER don't work together as expected.
    if (
        cpu_offload.offload_params
        and fsdp_init_mode == FSDPInitMode.CUDA_AFTER
    ):
        return

    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random free port in case the next test runs quickly; reusing the
    # same port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    # NOTE: We move the model to CUDA after initializing it with FSDP to
    # simulate real use cases where the full model cannot be loaded onto the
    # GPU, but its shards can.
    cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
    try:
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(
            cuda=(not cuda_after_init)
        )
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            cpu_offload=cpu_offload,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
        if cuda_after_init:
            model = model.cuda()
        input = torch.rand((1, 5), dtype=torch.float).to(device)
        output = model(input)
        loss = F.mse_loss(input, output)
        loss.backward()
    finally:
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]

def test_main_wrap_api(self, cpu_offload, fsdp_init_mode):
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
        # CPU offload and CUDA_AFTER don't work together, as expected.
        return

    move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

    class Nested(nn.Module):
        def __init__(self):
            super().__init__()
            self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

        def forward(self, input):
            return self.nested_lin(input)

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin4 = Nested()

        def forward(self, input):
            return self.lin4(self.lin3(self.lin2(self.lin1(input))))

    model = MyModel()
    wrapped_model = FSDP(
        model,
        fsdp_auto_wrap_policy=functools.partial(
            default_auto_wrap_policy,
            min_num_params=0,  # wrap all modules
        ),
        cpu_offload=cpu_offload,
    )
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        wrapped_model = wrapped_model.cuda()

    modules = [
        wrapped_model,
        wrapped_model.module.lin1,
        wrapped_model.module.lin2,
        wrapped_model.module.lin3,
        wrapped_model.module.lin4,
        # Nested FSDP
        wrapped_model.module.lin4.module.nested_lin,
    ]
    for module in modules:
        self.assertTrue(isinstance(module, FSDP))
        self._check_cpu_offload(module, cpu_offload)

    # Run the model a few times as a sanity check.
    optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
    inp = torch.ones(1).cuda()
    for _ in range(6):
        optim.zero_grad()
        loss = wrapped_model(inp).sum()
        loss.backward()
        optim.step()

def __init__(self, nested):
    super().__init__()
    # TODO: test the various init modes.
    move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
    # If nested=True, the FSDP module is nested one layer deep and we should
    # pick that up.
    if nested:
        self.lin1 = nn.Sequential(
            _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
            FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)),
        )
    else:
        self.lin1 = FSDP(
            _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)
        )
    self.lin2 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
    self.lin3 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

def test_error_already_wrapped(self, nested, fsdp_init_mode):
    """
    Test that an error is raised if we attempt to wrap when submodules are
    already FSDP.
    """
    wrapped_fsdp = self._get_already_wrapped_fsdp(
        nested=nested, fsdp_init_mode=fsdp_init_mode
    )
    if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        wrapped_fsdp = wrapped_fsdp.cuda()
    with self.assertRaisesRegex(ValueError, "to NOT be FullyShardedDataParallel"):
        mod = FSDP(wrapped_fsdp, fsdp_auto_wrap_policy=default_auto_wrap_policy)

def test_auto_wrap_api(self):
    """
    Test to ensure that with auto wrap, we wrap child modules correctly based
    on min_num_params. A single ``nn.Linear(5, 5)`` does not exceed the
    min_num_params threshold, but several of them combined do.
    """
    sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    model = FSDP(
        sequential,
        process_group=self.process_group,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
    )
    TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

def test_auto_wrap_preset_exclude_wrap_include_children(self):
    """
    Test to ensure excluded modules are not wrapped, but their children are
    wrapped if their param size is greater than min_num_params.
    """
    sequential = nn.ModuleList([nn.Linear(10, 10)])
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    model = FSDP(
        sequential,
        process_group=self.process_group,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
    )
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model[0], FSDP))

def test_wrap(self, wrap_method):
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5))
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        layer = FSDP(
            nn.Linear(5, 5),
            process_group=self.process_group,
            fsdp_auto_wrap_policy=functools.partial(
                default_auto_wrap_policy, min_num_params=1
            ),
        )
    self.assertTrue(isinstance(layer, FSDP))
    self.assertEqual(layer.rank, self.process_group.rank())
    self.assertEqual(layer.world_size, self.process_group.size())

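# Note (illustrative only, not asserted by the test above): outside of an
# enable_wrap() context, wrap() is expected to return the module unchanged,
# which lets model code call wrap() unconditionally:
#
#     layer = wrap(nn.Linear(5, 5))  # no enable_wrap() context active
#     assert not isinstance(layer, FSDP)
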
def test_auto_wrap_preset_force_leaf(self):
    """
    Test to ensure force-leaf modules are not wrapped, nor are their children.
    The default_auto_wrap_policy forces leaf modules of type
    {nn.MultiheadAttention} to not be wrapped.
    """
    sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    model = FSDP(
        sequential,
        process_group=self.process_group,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
    )
    self.assertTrue(isinstance(model.module[0], FSDP))
    # Assert that the children of the multihead attention are not wrapped.
    self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
    self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

def test_basic_checkpoint_end_to_end(self, cpu_offload):
    seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
    # Runs FSDP with no checkpointing.
    fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
    # Runs checkpoint-wrapped FSDP.
    checkpointed_fsdp = checkpoint_wrapper(FSDP(deepcopy(seq), cpu_offload=cpu_offload))
    # Runs FSDP-wrapped checkpointed module.
    fsdp_wrapped_checkpoint = FSDP(checkpoint_wrapper(deepcopy(seq)), cpu_offload=cpu_offload)
    # Runs FSDP with manual calls to checkpoint.
    fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)

    # Note that reentrant-based checkpointing requires the input to have its
    # requires_grad flag set.
    inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

    models = [
        fsdp_only_seq,
        checkpointed_fsdp,
        fsdp_wrapped_checkpoint,
        fsdp_call_checkpoint,
    ]

    for _ in range(6):
        losses = []
        outputs = []
        for m in models:
            if m == fsdp_call_checkpoint:
                out = checkpoint(m, inp)
            else:
                out = m(inp)
            loss = out.sum()
            loss.backward()
            losses.append(loss)
            outputs.append(out)

        self._verify_parity(losses, outputs, models)

def test_auto_wrap_preset_exclude_wrap(self):
    """
    Test to ensure excluded modules are not wrapped, regardless of whether the
    total param size is greater than min_num_params. The
    default_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.
    """
    sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    model = FSDP(
        sequential,
        process_group=self.process_group,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
    )
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model[0], nn.Linear))
    self.assertTrue(isinstance(model[1], nn.Linear))

def test_auto_wrap_preset_force_leaf_custom(self):
    """
    Test to ensure force-leaf modules are not wrapped.
    """
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy,
        min_num_params=40,
        force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
            {nn.Linear}
        ),
    )
    sequential = nn.Sequential(
        nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
    )
    model = FSDP(
        sequential,
        process_group=self.process_group,
        fsdp_auto_wrap_policy=my_auto_wrap_policy,
    )
    # Model was wrapped in FSDP as no inner modules were wrapped.
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model.module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[1], nn.ModuleList))

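# A minimal sketch (illustrative only, not part of the test suite): an
# auto-wrap policy does not have to be a functools.partial over
# default_auto_wrap_policy; any callable matching the assumed
# (module, recurse, unwrapped_params) -> bool convention should work.
# The policy name below is hypothetical.
#
#     def linear_only_policy(module, recurse, unwrapped_params):
#         if recurse:
#             # Always keep recursing into children.
#             return True
#         # Only wrap nn.Linear leaves, regardless of parameter count.
#         return isinstance(module, nn.Linear)
#
#     model = FSDP(sequential, fsdp_auto_wrap_policy=linear_only_policy)
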
def test_auto_wrap_foo(self, wrap_method):
    """
    Test to ensure that with auto wrap, we wrap child modules correctly based
    on min_num_params. A single ``nn.Linear(5, 5)`` does not exceed the
    min_num_params threshold, but several of them combined do.
    """
    sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
    TestFSDPWrap.NestedSequentialModel.verify_model(self, model)

def test_auto_wrap_preset_force_leaf(self, wrap_method):
    """
    Test to ensure force-leaf modules are not wrapped, nor are their children.
    The default_auto_wrap_policy forces leaf modules of type
    {nn.MultiheadAttention} to not be wrapped.
    """
    sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
    self.assertTrue(isinstance(model.module[0], FSDP))
    # Assert that the children of the multihead attention are not wrapped.
    self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
    self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

def test_auto_wrap_preset_exclude_wrap_include_children(self, wrap_method):
    """
    Test to ensure excluded modules are not wrapped, but their children are
    wrapped if their param size is greater than min_num_params.
    """
    sequential = nn.ModuleList([nn.Linear(10, 10)])
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
    if wrap_method == WrapMethod.WRAP_API:
        self.assertTrue(isinstance(model, nn.ModuleList))
    else:
        self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model[0], FSDP))

def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
    seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
    # Runs FSDP with no checkpointing.
    fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
    # Runs checkpoint-wrapped FSDP.
    checkpointed_fsdp = checkpoint_wrapper(
        FSDP(deepcopy(seq), cpu_offload=cpu_offload),
        offload_to_cpu=offload_activations,
    )
    # Runs FSDP-wrapped checkpointed module.
    fsdp_wrapped_checkpoint = FSDP(
        checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
        cpu_offload=cpu_offload,
    )
    # Runs FSDP with manual calls to checkpoint.
    fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)

    # Note that reentrant-based checkpointing requires the input to have its
    # requires_grad flag set.
    inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

    models = [
        fsdp_only_seq,
        checkpointed_fsdp,
        fsdp_wrapped_checkpoint,
        fsdp_call_checkpoint,
    ]
    # When activations are offloaded, a device-to-host memcpy should show up
    # in the profiler trace.
    offload_to_cpu_event = "Memcpy DtoH"

    for i in range(6):
        losses = []
        outputs = []
        for m in models:
            check_offload = m != fsdp_only_seq and i == 0 and offload_activations
            profiler_ctx = (
                torch.profiler.profile(use_cuda=True)
                if check_offload
                else contextlib.suppress()
            )
            with profiler_ctx as prof:
                if m == fsdp_call_checkpoint:
                    # Manually offload activations to CPU when checkpointing
                    # via the functional API.
                    offload_ctx = (
                        torch.autograd.graph.save_on_cpu(pin_memory=True)
                        if offload_activations
                        else contextlib.suppress()
                    )
                    with offload_ctx:
                        out = checkpoint(m, inp)
                else:
                    out = m(inp)

            if check_offload:
                event_names = [event.name for event in prof.events()]
                offload_occurred = any(
                    offload_to_cpu_event in name for name in event_names
                )
                self.assertTrue(offload_occurred)

            loss = out.sum()
            loss.backward()
            losses.append(loss)
            outputs.append(out)

        self._verify_parity(losses, outputs, models)
