def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload, use_device_id):
    # CPU offload and moving to CUDA after init do not work together as expected.
    if cpu_offload.offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER:
        return

    device = torch.device("cuda")
    torch.cuda.set_device(0)
    device_id = (
        torch.device("cuda", torch.cuda.current_device()) if use_device_id else None
    )

    # Use a random port in case the next test runs quickly; reusing the same
    # port would cause a conflict.
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(find_free_port())

    file_name = tempfile.NamedTemporaryFile(delete=False).name
    torch.distributed.init_process_group(
        backend="nccl",
        init_method=f"{FILE_SCHEMA}_{file_name}",
        rank=0,
        world_size=1,
    )

    # NOTE: We move the model to CUDA after wrapping with FSDP to simulate real
    # use cases where the full model cannot be loaded onto the GPU, but its
    # shards can.
    cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
    try:
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(
            cuda=(not cuda_after_init)
        )
        my_auto_wrap_policy = functools.partial(
            size_based_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            cpu_offload=cpu_offload,
            auto_wrap_policy=my_auto_wrap_policy,
            device_id=device_id,
        )
        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
        if cuda_after_init:
            model = model.cuda()
        input = torch.rand((1, 5), dtype=torch.float).to(device)
        output = model(input)
        loss = F.mse_loss(input, output)
        loss.backward()
    finally:
        torch.distributed.destroy_process_group()
        try:
            os.remove(file_name)
        except FileNotFoundError:
            pass
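# For context, a minimal sketch of what a size-based auto-wrap policy decides
# (a hypothetical illustration, not the actual `size_based_auto_wrap_policy`
# implementation): while recursing, always descend into children; at a leaf
# decision, wrap the module once its not-yet-wrapped parameter count reaches
# `min_num_params`. The name `_sketch_size_based_policy` and the exact
# callback signature are assumptions for illustration only.
def _sketch_size_based_policy(module, recurse, nonwrapped_numel, min_num_params=40):
    if recurse:
        # Keep traversing so nested submodules can be considered for wrapping.
        return True
    # Wrap this module only if enough parameters beneath it remain unwrapped.
    return nonwrapped_numel >= min_num_params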
def test_main_wrap_api(self, cpu_offload, backward_prefetch, forward_prefetch, cuda_init_mode):
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
        # These do not work together; skipping is expected.
        return

    move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

    class Nested(nn.Module):
        def __init__(self):
            super().__init__()
            self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

        def forward(self, input):
            return self.nested_lin(input)

    class MyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
            self.lin4 = Nested()

        def forward(self, input):
            return self.lin4(self.lin3(self.lin2(self.lin1(input))))

    model = MyModel()
    wrapped_model = FSDP(
        model,
        auto_wrap_policy=functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=0,  # wrap all modules
        ),
        cpu_offload=cpu_offload,
        backward_prefetch=backward_prefetch,
        forward_prefetch=forward_prefetch,
    )
    if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
        wrapped_model = wrapped_model.cuda()

    modules_in_fsdp_graph_order = [
        wrapped_model.module.lin1,
        wrapped_model.module.lin2,
        wrapped_model.module.lin3,
        wrapped_model.module.lin4.module.nested_lin,
        wrapped_model.module.lin4,
        wrapped_model,
    ]

    for module in modules_in_fsdp_graph_order:
        self.assertTrue(isinstance(module, FSDP))
        self._check_cpu_offload(module, cpu_offload)
        self._check_backward_prefetch(module, backward_prefetch)
        self._check_forward_prefetch(module, forward_prefetch)

    # Run the model a few times as a sanity check.
    optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
    inp = torch.ones(1).cuda()
    for _ in range(6):
        optim.zero_grad()
        loss = wrapped_model(inp).sum()
        loss.backward()
        optim.step()

    # Since we ran with backward prefetch, verify the backward-prefetch-related
    # bookkeeping.
    for i, module in enumerate(modules_in_fsdp_graph_order):
        self.assertEqual(i, module._my_fsdp_idx_in_graph)
        self.assertTrue(module._fsdp_graph_order == modules_in_fsdp_graph_order)
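# Both tests above rely on the shared `_maybe_cuda` test helper. A plausible
# reference implementation is sketched below (an assumption based on how the
# helper is used here, not necessarily the exact upstream definition): it
# moves the module to the default CUDA device only when requested, which is
# what lets CUDA_BEFORE place submodules on GPU before FSDP wrapping.
def _maybe_cuda_sketch(module: nn.Module, move_to_cuda: bool) -> nn.Module:
    # Move parameters/buffers to GPU when asked; otherwise leave them on CPU.
    return module.cuda() if move_to_cuda else module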