Ejemplo n.º 1
0
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload,
                                  use_device_id):
        """Smoke-test size-based auto wrapping with one forward/backward pass.

        Initializes a single-process NCCL process group, wraps
        ``NestedSequentialModel`` in FSDP with a size-based auto-wrap policy
        (``min_num_params=40``), verifies the wrapped model, and runs one
        forward and backward pass on a random ``(1, 5)`` input.

        Args:
            cuda_init_mode: ``CUDAInitMode`` value. With ``CUDA_AFTER`` the
                model is built without CUDA and moved there after wrapping.
            cpu_offload: CPU offload config forwarded to FSDP; only its
                ``offload_params`` attribute is read directly here.
            use_device_id: When truthy, the current CUDA device is passed to
                FSDP as ``device_id``; otherwise ``None`` is passed.
        """
        # CPU offload and CUDA after don't work together as expected.
        if (cpu_offload.offload_params
                and cuda_init_mode == CUDAInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (torch.device("cuda", torch.cuda.current_device())
                     if use_device_id else None)

        # Use a random free port: if the next test starts quickly, reusing the
        # same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        # Only the path is needed, so close the handle right away instead of
        # leaving it to the garbage collector.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            file_name = tmp_file.name

        try:
            torch.distributed.init_process_group(
                backend="nccl",
                init_method=f"{FILE_SCHEMA}_{file_name}",
                rank=0,
                world_size=1,
            )

            # NOTE: We move model to CUDA after init with FSDP to simulate
            # real use cases where full model cannot be loaded onto GPU, but
            # their shards can.
            cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
            try:
                sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                    cuda=(not cuda_after_init))
                my_auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=40)
                model = FSDP(sequential,
                             cpu_offload=cpu_offload,
                             auto_wrap_policy=my_auto_wrap_policy,
                             device_id=device_id)
                TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
                if cuda_after_init:
                    model = model.cuda()
                inp = torch.rand((1, 5), dtype=torch.float).to(device)
                output = model(inp)
                loss = F.mse_loss(inp, output)
                loss.backward()
            finally:
                torch.distributed.destroy_process_group()
        finally:
            # Always clean up the temp file, including when the test body or
            # process-group setup raises.
            try:
                os.remove(file_name)
            except FileNotFoundError:
                # The file may already be gone (presumably removed during
                # process-group teardown); nothing left to clean up.
                pass
Ejemplo n.º 2
0
    def test_main_wrap_api(self, cpu_offload, backward_prefetch,
                           forward_prefetch, cuda_init_mode):
        """Check that FSDP auto-wrap with ``min_num_params=0`` wraps every module.

        Builds a small model of three linear layers plus a nested module,
        wraps it with a size-based policy that wraps all modules, and checks
        that each expected submodule is an FSDP instance carrying the given
        CPU-offload and prefetch settings. It then runs a few SGD steps and
        verifies the recorded FSDP graph order.

        Args:
            cpu_offload: CPU offload config forwarded to FSDP.
            backward_prefetch: Backward prefetch setting forwarded to FSDP.
            forward_prefetch: Forward prefetch setting forwarded to FSDP.
            cuda_init_mode: ``CUDAInitMode`` value. ``CUDA_BEFORE`` moves the
                layers to CUDA before wrapping; ``CUDA_AFTER`` moves the
                wrapped model afterwards.
        """
        cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
        if cuda_after_init and cpu_offload.offload_params:
            # they don't work together, expected
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def _make_linear():
            # Single place that builds the 1x1 bias-free layers used below.
            return _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _make_linear()

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                # Keep the lin1..lin4 construction order unchanged.
                self.lin1 = _make_linear()
                self.lin2 = _make_linear()
                self.lin3 = _make_linear()
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            auto_wrap_policy=functools.partial(
                size_based_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_after_init:
            wrapped_model = wrapped_model.cuda()

        # Expected order: the three linears, the inner nested linear, its
        # wrapper, and finally the root wrapper.
        modules_in_fsdp_graph_order = [
            wrapped_model.module.lin1, wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4.module.nested_lin,
            wrapped_model.module.lin4, wrapped_model
        ]

        for module in modules_in_fsdp_graph_order:
            # assertIsInstance reports the actual type on failure.
            self.assertIsInstance(module, FSDP)
            self._check_cpu_offload(module, cpu_offload)
            self._check_backward_prefetch(module, backward_prefetch)
            self._check_forward_prefetch(module, forward_prefetch)

        # Run model a few times for sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(),
                                lr=1e-2,
                                momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()

        # Since we ran with backward prefetch, verify backward prefetch related
        # data.
        for i, module in enumerate(modules_in_fsdp_graph_order):
            self.assertEqual(i, module._my_fsdp_idx_in_graph)
            self.assertTrue(
                module._fsdp_graph_order == modules_in_fsdp_graph_order)