def test_auto_wrap_smoke_test(self):
    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random port so that back-to-back test runs do not race on the
    # same port and conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    try:
        with enable_wrap(wrapper_cls=FSDP):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        model.to(device)
        input = torch.rand((1, 5), dtype=torch.float).to(device)
        output = model(input)
        loss = F.mse_loss(input, output)
        loss.backward()
    finally:
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
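# The smoke test above relies on a ``find_free_port`` helper to avoid port
# collisions when initializing the single-rank process group; it is assumed to
# come from the shared test utilities. A minimal sketch of such a helper,
# under that assumption (not the actual implementation), could look like this:
import socket

def _sketch_find_free_port() -> int:
    # Bind to port 0 so the OS picks a currently free port, then report it.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("localhost", 0))
        return s.getsockname()[1]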
def test_auto_wrap_preset_force_leaf_custom(self, wrap_method):
    """
    Test to ensure force-leaf modules are not wrapped.
    """
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy,
        min_num_params=40,
        force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union({nn.Linear}),
    )
    sequential = nn.Sequential(
        nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(
            auto_wrap_policy=my_auto_wrap_policy,
            wrapper_cls=FSDP,
            process_group=self.process_group,
        ):
            model = auto_wrap(sequential)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )

    # Model was wrapped in FSDP as no inner modules were wrapped.
    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model.module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[1], nn.ModuleList))
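# The ``wrap_method`` parameter used by the parametrized tests is assumed to be
# a small enum distinguishing the two wrapping entry points exercised here:
# the ``enable_wrap()``/``auto_wrap()`` API and auto-wrapping via the FSDP
# constructor. A minimal sketch of what that enum might look like (the actual
# definition lives elsewhere in the test suite):
from enum import Enum, auto

class WrapMethod(Enum):
    FSDP_CTOR = auto()  # pass fsdp_auto_wrap_policy to the FSDP constructor
    WRAP_API = auto()   # use enable_wrap() + auto_wrap()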
def test_auto_wrap_preset_exclude_wrap(self, wrap_method):
    """
    Test to ensure excluded modules are not wrapped, even if their total
    parameter size exceeds min_num_params. default_auto_wrap_policy excludes
    wrapping for {nn.ModuleList, nn.ModuleDict}.
    """
    sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )

    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )

    # Note that the outermost module is an FSDP instance for the FSDP_CTOR
    # approach, because calling the FSDP constructor necessarily returns an
    # FSDP instance. If we do not want to shard the outermost module based on
    # the policy, we can apply the policy manually to the outermost instance
    # and skip sharding it.
    if wrap_method == WrapMethod.WRAP_API:
        self.assertTrue(isinstance(model, nn.ModuleList))
    else:
        self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model[0], nn.Linear))
    self.assertTrue(isinstance(model[1], nn.Linear))
def test_auto_wrap_preset_exclude_wrap_include_children(self):
    """
    Test to ensure excluded modules are not wrapped, but their children are
    wrapped if their parameter size exceeds min_num_params.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    self.assertTrue(isinstance(model, nn.ModuleList))
    self.assertTrue(isinstance(model[0], FSDP))
def test_auto_wrap_preset_exclude_wrap(self):
    """
    Test to ensure excluded modules are not wrapped, even if their total
    parameter size exceeds min_num_params. default_auto_wrap_policy excludes
    wrapping for {nn.ModuleList, nn.ModuleDict}.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    self.assertTrue(isinstance(model, nn.ModuleList))
    self.assertTrue(isinstance(model[0], nn.Linear))
    self.assertTrue(isinstance(model[1], nn.Linear))
def test_auto_wrap_preset_force_leaf(self):
    """
    Test to ensure force-leaf modules are not wrapped and neither are their
    children. default_auto_wrap_policy forces modules of type
    {nn.MultiheadAttention} to be treated as leaves.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.MultiheadAttention(100, 1)
        )
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    self.assertTrue(isinstance(model.module[0], FSDP))
    # Assert that the children of the multihead attention are not wrapped.
    self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
    self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))
def test_auto_wrap(self):
    """
    Test to ensure that auto wrap wraps child modules correctly based on
    min_num_params. A single ``nn.Linear(5, 5)`` does not exceed the bucket
    size, but two of them combined do.
    """
    with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
        sequential = nn.Sequential(
            nn.Linear(5, 5),
            nn.Linear(5, 5),
            nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
        )
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)

    self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model.module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[1], nn.Linear))
    self.assertTrue(isinstance(model.module[2], FSDP))
    self.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
    self.assertTrue(isinstance(model.module[2].module[1], nn.Linear))
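# For reference, the arithmetic behind the min_num_params=40 threshold used
# above: a single nn.Linear(5, 5) holds 5 * 5 + 5 = 30 parameters and stays
# unwrapped, while the nested nn.Sequential of two such layers holds 60 and
# therefore gets its own FSDP wrapper. A quick standalone check:
_single = sum(p.numel() for p in nn.Linear(5, 5).parameters())  # 30
_pair = sum(
    p.numel() for p in nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)).parameters()
)  # 60
assert _single < 40 < _pair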
def test_auto_wrap_smoke_test(self, wrap_method, fsdp_init_mode, cpu_offload):
    # CPU offload and moving the model to CUDA after FSDP init do not work
    # together as expected, so skip that combination.
    if cpu_offload.offload_params and fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
        return

    device = torch.device("cuda")
    torch.cuda.set_device(0)

    # Use a random port so that back-to-back test runs do not race on the
    # same port and conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())
    torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

    # NOTE: We move the model to CUDA after wrapping with FSDP to simulate
    # real use cases where the full model cannot be loaded onto the GPU, but
    # its shards can.
    cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
    try:
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(
            cuda=(not cuda_after_init)
        )
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, cpu_offload=cpu_offload):
                model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
        else:
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                fsdp_auto_wrap_policy=my_auto_wrap_policy,
            )
        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
        if cuda_after_init:
            model = model.cuda()
        input = torch.rand((1, 5), dtype=torch.float).to(device)
        output = model(input)
        loss = F.mse_loss(input, output)
        loss.backward()
    finally:
        torch.distributed.destroy_process_group()
        del os.environ["MASTER_ADDR"]
        del os.environ["MASTER_PORT"]
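# The smoke test above uses a ``TestFSDPWrap.NestedSequentialModel`` helper
# that is not shown in this section. Based on the structure the other tests
# assert (two outer Linear(5, 5) layers plus a nested Sequential of two more),
# a hypothetical sketch of that helper could look like the following; the
# exact implementation in the test suite may differ.
class NestedSequentialModelSketch:
    @staticmethod
    def get_model(cuda: bool = True) -> nn.Sequential:
        model = nn.Sequential(
            nn.Linear(5, 5),
            nn.Linear(5, 5),
            nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
        )
        return model.cuda() if cuda else model

    @staticmethod
    def verify_model(test_case, model) -> None:
        # Expect an FSDP root whose nested Sequential is also FSDP-wrapped,
        # mirroring the assertions in test_auto_wrap above.
        test_case.assertTrue(isinstance(model, FSDP))
        test_case.assertTrue(isinstance(model.module[0], nn.Linear))
        test_case.assertTrue(isinstance(model.module[1], nn.Linear))
        test_case.assertTrue(isinstance(model.module[2], FSDP))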
def test_auto_wrap_foo(self, wrap_method):
    """
    Test to ensure that auto wrap wraps child modules correctly based on
    min_num_params. A single ``nn.Linear(5, 5)`` does not exceed the bucket
    size, but two of them combined do.
    """
    sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )
    TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
def test_auto_wrap_preset_exclude_wrap_include_children(self, wrap_method):
    """
    Test to ensure excluded modules are not wrapped, but their children are
    wrapped if their parameter size exceeds min_num_params.
    """
    sequential = nn.ModuleList([nn.Linear(10, 10)])
    my_auto_wrap_policy = functools.partial(
        default_auto_wrap_policy, min_num_params=40
    )
    if wrap_method == WrapMethod.WRAP_API:
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            model = auto_wrap(sequential, auto_wrap_policy=my_auto_wrap_policy)
    else:
        assert wrap_method == WrapMethod.FSDP_CTOR
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy,
        )

    # As above, the FSDP constructor always returns an FSDP instance for the
    # outermost module, whereas the wrap API preserves the excluded ModuleList.
    if wrap_method == WrapMethod.WRAP_API:
        self.assertTrue(isinstance(model, nn.ModuleList))
    else:
        self.assertTrue(isinstance(model, FSDP))
    self.assertTrue(isinstance(model[0], FSDP))
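# To summarize the two entry points exercised by the wrap_method parameter, a
# minimal standalone usage sketch (assuming a process group has already been
# initialized) might look like this. Argument names follow the API version
# used in these tests (e.g. ``fsdp_auto_wrap_policy``) and may differ in newer
# releases; the helper name is illustrative only.
def _sketch_wrap(module: nn.Module, process_group, use_wrap_api: bool):
    policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
    if use_wrap_api:
        # Wrap API path: eligible children are wrapped; the root is wrapped
        # only if the policy selects it (excluded roots are left as-is).
        with enable_wrap(wrapper_cls=FSDP, process_group=process_group):
            return auto_wrap(module, auto_wrap_policy=policy)
    # Constructor path: the returned root is always an FSDP instance.
    return FSDP(module, process_group=process_group, fsdp_auto_wrap_policy=policy)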