def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
    """Smoke-test FSDP auto-wrapping with a single forward/backward pass.

    Brings up a one-process NCCL process group (rank 0, world size 1),
    wraps ``TestFSDPWrap.NestedSequentialModel`` using an auto-wrap policy
    with ``min_num_params=40``, hands the wrapped model to
    ``NestedSequentialModel.verify_model`` and then runs one forward and
    backward pass on a random ``(1, 5)`` float tensor.

    Args:
        fsdp_init_mode: ``FSDPInitMode`` value. With ``CUDA_AFTER`` the
            model is built off-GPU and moved to CUDA only after wrapping.
        cpu_offload: Offload setting forwarded to ``FSDP``; its
            ``offload_params`` flag decides whether the case is skipped.
    """
    # NOTE: We move the model to CUDA after init with FSDP to simulate real
    # use cases where the full model cannot be loaded onto the GPU, but its
    # shards can.
    cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER

    # CPU offload and CUDA-after don't work together as expected.
    if cpu_offload.offload_params and cuda_after_init:
        return

    device = torch.device("cuda")
    torch.cuda.set_device(0)

    try:
        # Use a random free port: if the next test starts quickly, reusing
        # the same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
        # Only reached when the group exists, so the inner ``finally`` never
        # tries to destroy a process group that failed to initialize.
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                cuda=(not cuda_after_init)
            )
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(
                sequential,
                cpu_offload=cpu_offload,
                fsdp_auto_wrap_policy=my_auto_wrap_policy,
            )
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            inp = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(inp)
            loss = F.mse_loss(inp, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
    finally:
        # Always drop the rendezvous variables, even when setup or teardown
        # of the process group raised, so they cannot leak into other tests.
        os.environ.pop("MASTER_ADDR", None)
        os.environ.pop("MASTER_PORT", None)
def test_main_wrap_api(self, cpu_offload, fsdp_init_mode):
    """Check that a wrap-everything policy turns every submodule into FSDP.

    Builds a small model of three 1x1 bias-free linear layers plus a nested
    module holding a fourth, wraps it with ``min_num_params=0`` and asserts
    that the root, each layer, the nested module and its inner layer are
    all ``FSDP`` instances. Each one is also passed to
    ``self._check_cpu_offload``. Finally, six SGD steps are run as a sanity
    check.

    Args:
        cpu_offload: Offload setting forwarded to ``FSDP``; its
            ``offload_params`` flag decides whether the case is skipped.
        fsdp_init_mode: ``FSDPInitMode`` value controlling whether layers
            go through ``_maybe_cuda`` with a true flag before wrapping
            (``CUDA_BEFORE``) or the wrapped model is moved afterwards
            (``CUDA_AFTER``).
    """
    cuda_after = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
    if cuda_after and cpu_offload.offload_params:
        # they don't work together, expected
        return

    move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

    def make_linear():
        # Single place that builds the 1x1 bias-free layer used everywhere.
        return _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

    class Nested(nn.Module):
        def __init__(self):
            super().__init__()
            self.nested_lin = make_linear()

        def forward(self, x):
            return self.nested_lin(x)

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = make_linear()
            self.lin2 = make_linear()
            self.lin3 = make_linear()
            self.lin4 = Nested()

        def forward(self, x):
            hidden = self.lin1(x)
            hidden = self.lin2(hidden)
            hidden = self.lin3(hidden)
            return self.lin4(hidden)

    # min_num_params=0 -> wrap all modules
    wrap_everything = functools.partial(default_auto_wrap_policy, min_num_params=0)
    wrapped_model = FSDP(
        MyModel(),
        fsdp_auto_wrap_policy=wrap_everything,
        cpu_offload=cpu_offload,
    )
    if cuda_after:
        wrapped_model = wrapped_model.cuda()

    expected_fsdp = (
        wrapped_model,
        wrapped_model.module.lin1,
        wrapped_model.module.lin2,
        wrapped_model.module.lin3,
        wrapped_model.module.lin4,
        # Nested FSDP
        wrapped_model.module.lin4.module.nested_lin,
    )
    for candidate in expected_fsdp:
        self.assertTrue(isinstance(candidate, FSDP))
        self._check_cpu_offload(candidate, cpu_offload)

    # Run model a few times for sanity check.
    sgd = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
    ones = torch.ones(1).cuda()
    for _ in range(6):
        sgd.zero_grad()
        total = wrapped_model(ones).sum()
        total.backward()
        sgd.step()