Example 1
class TestFSDPWrap(FSDPTest):
    """
    Tests the main API for auto-wrapping with FSDP, which is to pass an
    auto_wrap_policy into the FSDP constructor.
    """

    def setUp(self) -> None:
        super().setUp()

    class NestedSequentialModel:
        @staticmethod
        def get_model(cuda=True):
            sequential = nn.Sequential(
                nn.Linear(5, 5),
                nn.Linear(5, 5),
                nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)),
            )
            if cuda:
                sequential = sequential.cuda()
            return sequential

        @staticmethod
        def verify_model(cls, model):
            cls.assertTrue(isinstance(model, FSDP))
            cls.assertTrue(isinstance(model.module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[1], nn.Linear))
            cls.assertTrue(isinstance(model.module[2], FSDP))
            # The following modules were not wrapped by the policy.
            cls.assertTrue(isinstance(model.module[2].module[0], nn.Linear))
            cls.assertTrue(isinstance(model.module[2].module[1], nn.Linear))

    def _get_linear(self, fin, fout):
        return nn.Linear(fin, fout, bias=False)

    def _get_already_wrapped_fsdp(
        self, fsdp_init_mode=FSDPInitMode.CUDA_BEFORE, nested=False
    ) -> FSDP:
        fn_self = self

        class MyModel(nn.Module):
            def __init__(self, nested):
                super().__init__()
                # TODO: test the various init modes.
                move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE
                # If nested=True, the FSDP module is nested one layer deep, and
                # the already-wrapped check should pick that up.
                if nested:
                    self.lin1 = nn.Sequential(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda),
                        FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)),
                    )
                else:
                    self.lin1 = FSDP(
                        _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)
                    )
                self.lin2 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))
                self.lin3 = FSDP(_maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda))

            def forward(self, input: torch.Tensor) -> torch.Tensor:
                return self.lin3(self.lin2(self.lin1(input)))

        model = MyModel(nested=nested)
        return model

    @skip_if_lt_x_gpu(2)
    @parametrize("nested", [True, False])
    @parametrize("fsdp_init_mode", [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE])
    def test_error_already_wrapped(self, nested, fsdp_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(nested=nested, fsdp_init_mode=fsdp_init_mode)
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        with self.assertRaisesRegex(ValueError, "to NOT be FullyShardedDataParallel"):
            mod = FSDP(wrapped_fsdp, fsdp_auto_wrap_policy=default_auto_wrap_policy)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)]
    )
    @parametrize(
        "fsdp_init_mode",
        [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]
    )
    def test_main_wrap_api(self, cpu_offload, fsdp_init_mode):

        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # As expected, CPU offload and CUDA_AFTER init do not work together; skip.
            return

        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            fsdp_auto_wrap_policy=functools.partial(
                default_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
        )
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules = [
            wrapped_model,
            wrapped_model.module.lin1,
            wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4,
            # Nested FSDP
            wrapped_model.module.lin4.module.nested_lin,
        ]
        for module in modules:
            self.assertTrue(isinstance(module, FSDP))
            self._check_cpu_offload(module, cpu_offload)
        # Run the model a few times as a sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()
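The pattern these tests exercise can also be run standalone. The sketch below is a minimal, hedged example (not part of the test suite): it assumes a single GPU, a single-process NCCL group initialized the same way as the smoke test in Example 2, an arbitrary free port, and the pre-rename API names used throughout these tests (fsdp_auto_wrap_policy, default_auto_wrap_policy).

import functools
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import default_auto_wrap_policy

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29500")  # arbitrary free port (assumption)
dist.init_process_group(backend="nccl", rank=0, world_size=1)

# Pass the auto-wrap policy straight into the FSDP constructor, as in
# test_main_wrap_api above.
model = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)).cuda()
wrapped = FSDP(
    model,
    fsdp_auto_wrap_policy=functools.partial(
        default_auto_wrap_policy, min_num_params=0  # wrap every submodule
    ),
)
loss = wrapped(torch.ones(1, 5, device="cuda")).sum()
loss.backward()
dist.destroy_process_group()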
Example 2
class TestAutoWrap(TestCase):
    def setUp(self) -> None:
        super().setUp()
        # For all the tests here, we use a fake group
        self.process_group = DummyProcessGroup(rank=0, size=1)

    @parametrize("wrap_method", [WrapMethod.FSDP_CTOR, WrapMethod.WRAP_API])
    def test_wrap(self, wrap_method):
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                layer = wrap(nn.Linear(5, 5))
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            layer = FSDP(
                nn.Linear(5, 5),
                process_group=self.process_group,
                fsdp_auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1)
            )
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, self.process_group.rank())
        self.assertEqual(layer.world_size, self.process_group.size())

    def test_wrap_disabled_outside_context(self):
        layer = wrap(nn.Linear(5, 5))
        self.assertTrue(isinstance(layer, nn.Linear))

    def test_wrap_override_defaults(self):
        new_process_group = DummyProcessGroup(rank=0, size=2)
        with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
            layer = wrap(nn.Linear(5, 5), process_group=new_process_group)
        self.assertTrue(isinstance(layer, FSDP))
        self.assertEqual(layer.rank, 0)
        self.assertEqual(layer.world_size, 2)

    def test_auto_wrap_api(self):
        """
        Test that auto wrap correctly wraps child modules based on min_num_params:
        a single ``nn.Linear(5, 5)`` does not exceed the threshold, but several combined do.
        """
        sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy
        )

        TestFSDPWrap.NestedSequentialModel.verify_model(self, model)


    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test that excluded modules are not wrapped, regardless of whether their total param size exceeds
        min_num_params. The default_auto_wrap_policy excludes {nn.ModuleList, nn.ModuleDict} from wrapping.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )

        model = FSDP(
            sequential,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=my_auto_wrap_policy
        )

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], nn.Linear))
        self.assertTrue(isinstance(model[1], nn.Linear))

    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test that excluded modules are not wrapped, but that their children are wrapped if their param
        size exceeds min_num_params.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)

        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model[0], FSDP))

    def test_auto_wrap_preset_force_leaf(self):
        """
        Test that force-leaf modules and their children are not wrapped. The
        default_auto_wrap_policy forces leaf modules of type {nn.MultiheadAttention} to not be wrapped.
        """
        sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)
        self.assertTrue(isinstance(model.module[0], FSDP))
        # Assert children of multihead attention are not wrapped
        self.assertTrue(isinstance(model.module[1], nn.MultiheadAttention))
        self.assertTrue(isinstance(model.module[1].out_proj, nn.Linear))

    def test_auto_wrap_preset_force_leaf_custom(self):
        """
        Test to ensure force-leaf modules are not wrapped.
        """
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy,
            min_num_params=40,
            force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
                {nn.Linear}
            ),
        )
        sequential = nn.Sequential(
            nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)
        # Model was wrapped in FSDP as no inner modules were wrapped.
        self.assertTrue(isinstance(model, FSDP))
        self.assertTrue(isinstance(model.module[0], nn.Linear))
        self.assertTrue(isinstance(model.module[1], nn.ModuleList))

    @unittest.skipIf(not torch.cuda.is_available(), "Test Requires CUDA")
    @parametrize("fsdp_init_mode", [FSDPInitMode.CUDA_BEFORE, FSDPInitMode.CUDA_AFTER])
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=False), CPUOffload(offload_params=True)]
    )
    def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
        # As expected, CPU offload and CUDA_AFTER init do not work together; skip.
        if (
            cpu_offload.offload_params and fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        ):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Use a random port so that back-to-back test runs do not conflict on the same port.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)

        # NOTE: We move the model to CUDA after FSDP init to simulate real use
        # cases where the full model cannot be loaded onto the GPU, but its shards can.
        cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        try:
            sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init))
            my_auto_wrap_policy = functools.partial(
                default_auto_wrap_policy, min_num_params=40
            )
            model = FSDP(sequential, cpu_offload=cpu_offload, fsdp_auto_wrap_policy=my_auto_wrap_policy)
            TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
            if cuda_after_init:
                model = model.cuda()
            input = torch.rand((1, 5), dtype=torch.float).to(device)
            output = model(input)
            loss = F.mse_loss(input, output)
            loss.backward()
        finally:
            torch.distributed.destroy_process_group()
            del os.environ["MASTER_ADDR"]
            del os.environ["MASTER_PORT"]
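For reference, the two wrapping entry points compared in test_wrap (the FSDP constructor vs. the enable_wrap/wrap context manager) can be exercised outside the test harness. This is a minimal sketch under the same single-GPU, single-process NCCL assumptions as above; the default process group is used instead of the tests' DummyProcessGroup.

import os

import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.wrap import enable_wrap, wrap

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29501")  # arbitrary free port (assumption)
dist.init_process_group(backend="nccl", rank=0, world_size=1)

# Outside enable_wrap(), wrap() is a no-op and returns the module unchanged
# (what test_wrap_disabled_outside_context checks).
plain = wrap(nn.Linear(5, 5))
assert isinstance(plain, nn.Linear)

# Inside enable_wrap(), wrap() applies the configured wrapper class, which is
# equivalent to calling FSDP(...) directly (WrapMethod.WRAP_API vs. FSDP_CTOR).
with enable_wrap(wrapper_cls=FSDP):
    layer = wrap(nn.Linear(5, 5).cuda())
assert isinstance(layer, FSDP)

dist.destroy_process_group()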
Example 3
class TestParityWithDDP(FSDPTest):
    """
    Compare losses and parameter values after several updates when using
    PyTorch DDP vs. FullyShardedDataParallel.
    """
    def _get_init_modes_for_test(self, cpu_offload):
        modes = [FSDPInitMode.CUDA_AFTER, FSDPInitMode.CUDA_BEFORE]
        # Note that FSDPInitMode.CUDA_NEVER currently works only with CPU
        # offload, since CPU offload explicitly brings the param back to the
        # CUDA device. In general it will not work, because we all_gather
        # p.data, which is on CPU, while NCCL only supports GPU tensors.
        if cpu_offload.offload_params:
            modes.append(FSDPInitMode.CUDA_NEVER)

        return modes

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_nested_wrapped_model(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(NestedWrappedModule,
                                             fsdp_init_mode=fsdp_init_mode,
                                             cpu_offload=cpu_offload)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_nested_all_wrapped_model(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModule,
                                             wrap_everything=True)
                self._test_identical_outputs(model_fn,
                                             fsdp_init_mode=fsdp_init_mode,
                                             cpu_offload=cpu_offload)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_transformer_parameterized(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(TransformerWithSharedParams,
                                             fsdp_init_mode=fsdp_init_mode,
                                             cpu_offload=cpu_offload)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_delayed_optim_step(self, cpu_offload):
        # We use a model with a long CUDA delay right before the optimizer step.
        # This tests our streams logic, and that we don't start the allgather
        # until after the optimization step completes.
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModuleWithDelay,
                                             delay_after_loss_ms=250)
                self._test_identical_outputs(model_fn,
                                             fsdp_init_mode=fsdp_init_mode,
                                             cpu_offload=cpu_offload)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_delayed_reduce_scatter(self, cpu_offload):
        # We insert a delay in the torch.distributed._reduce_scatter_base op, so that
        # the post_backward_stream takes much longer than the backward pass.
        # This tests that we properly block at the end of the backward pass for
        # the reductions to finish.
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(NestedWrappedModuleWithDelay,
                                             delay_before_reduction_ms=250)
                self._test_identical_outputs(model_fn,
                                             fsdp_init_mode=fsdp_init_mode,
                                             cpu_offload=cpu_offload)

    def _dummy_ddp_fn(self, model):
        return DummyDDP(model)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_mixture_of_experts(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                self._test_identical_outputs(
                    MixtureOfExperts,
                    # MixtureOfExperts implements custom reduce logic, so the reference
                    # behavior should use that logic instead of PyTorch DDP.
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                )

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True),
         CPUOffload(offload_params=False)])
    def test_mixture_of_experts_with_delay_before_free(self, cpu_offload):
        init_modes = self._get_init_modes_for_test(cpu_offload)
        for fsdp_init_mode in init_modes:
            with self.subTest(fsdp_init_mode=fsdp_init_mode):
                model_fn = functools.partial(MixtureOfExperts,
                                             delay_before_free_ms=250)
                self._test_identical_outputs(
                    model_fn,
                    ref_ddp_fn=self._dummy_ddp_fn,
                    fsdp_init_mode=fsdp_init_mode,
                    cpu_offload=cpu_offload,
                )
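The CPU-offload note in _get_init_modes_for_test is what the parameter-device check in _test_identical_outputs (Example 4) verifies. As a standalone illustration, the sketch below assumes a single GPU, a single-process NCCL group, and that CPUOffload is importable from torch.distributed.fsdp (the snippets above do not show their imports): with offload_params=True the sharded parameters should sit on CPU right after construction.

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import CPUOffload, FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29502")  # arbitrary free port (assumption)
dist.init_process_group(backend="nccl", rank=0, world_size=1)

# CUDA_BEFORE-style init: move to GPU first, then wrap with param offloading.
model = nn.Sequential(nn.Linear(5, 5), nn.Linear(5, 5)).cuda()
fsdp_model = FSDP(model, cpu_offload=CPUOffload(offload_params=True))

# Outside of forward/backward, the (sharded) parameters live on CPU, which is
# exactly what _test_identical_outputs asserts for this configuration.
assert all(p.device == torch.device("cpu") for p in fsdp_model.parameters())

dist.destroy_process_group()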
Example 4
    def _test_identical_outputs(self,
                                model_init_fn,
                                *args,
                                ref_ddp_fn=None,
                                num_steps=2,
                                fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
                                lr=0.01,
                                cpu_offload=CPUOffload(),
                                **kwargs):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank)
        else:
            model = ref_ddp_fn(model)
        ref_loss = self._train_for_several_steps(model,
                                                 num_steps,
                                                 autocast=False,
                                                 lr=lr,
                                                 fsdp_cpu_offload=cpu_offload)
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(group=group,
                                  wrap_fsdp=True,
                                  fsdp_init_mode=fsdp_init_mode,
                                  cpu_offload=cpu_offload)
        except Exception as e:
            raise ValueError(
                f"model_init_fn {model_init_fn} got error {str(e)}"
            ) from e

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(model, cpu_offload=cpu_offload)
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we skip this check for FSDPInitMode.CUDA_AFTER since, when
        # offloading params, we expect the FSDP code to raise the error that we
        # check for below.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of whether the param is sharded.
                self.assertEqual(p.device, torch.device("cpu"),
                                 f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (self.assertRaisesRegex(AssertionError,
                                      "Expected param to be on CPU")
               if only_check_err else suppress())
        with ctx:
            shard_loss = self._train_for_several_steps(
                model,
                num_steps,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
            )
        # We only check for the error when we have the following setup:
        #   model = FSDP(model, cpu_offload=True)
        #   model = model.cuda()
        # so skip the rest of the checks in that case.
        if only_check_err:
            return
        # With CPU offload, the next call moves the model params to GPU.
        # Sanity check that the params are on CPU beforehand.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual({torch.device("cpu")}, device_set,
                             f"Got device set {device_set}")
        get_full_params(model)
        shard_full_params = list(model.parameters())

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
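In miniature, the parity check above boils down to: train a DDP reference and an FSDP copy of the same model on the same data, then compare the losses. The sketch below is a single-process simplification (assumptions: one GPU, one-rank NCCL group, identical initial weights via deepcopy); the real tests run across multiple ranks and also compare full parameters via get_full_params.

import copy
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29503")  # arbitrary free port (assumption)
dist.init_process_group(backend="nccl", rank=0, world_size=1)

torch.manual_seed(0)
base = nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)).cuda()
ddp_model = nn.parallel.DistributedDataParallel(copy.deepcopy(base), device_ids=[0])
fsdp_model = FSDP(copy.deepcopy(base))

def train(model, num_steps=2, lr=0.01):
    # Mirrors _train_for_several_steps: a few SGD steps on a fixed input.
    optim = torch.optim.SGD(model.parameters(), lr=lr)
    inp = torch.ones(4, 8, device="cuda")
    loss = None
    for _ in range(num_steps):
        optim.zero_grad()
        loss = model(inp).sum()
        loss.backward()
        optim.step()
    return loss.detach()

ref_loss = train(ddp_model)
shard_loss = train(fsdp_model)
torch.testing.assert_allclose(ref_loss, shard_loss)

dist.destroy_process_group()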
Example 5
class TestFSDPCheckpoint(FSDPTest):

    class SequentialModule(nn.Module):
        def __init__(self, checkpoint_layer=False, wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            super().__init__()
            l1 = nn.Linear(3, 3).cuda()
            l2 = nn.Linear(3, 3).cuda()
            l3 = nn.Linear(3, 3).cuda()

            if checkpoint_layer:
                l1 = checkpoint_wrapper(l1)
                l2 = checkpoint_wrapper(l2)
                l3 = checkpoint_wrapper(l3)

            fsdp_wrapper = partial(
                _maybe_wrap_fsdp,
                wrap_fsdp=wrap_fsdp,
                *fsdp_args,
                **fsdp_kwargs
            )
            self.ffn = nn.Sequential(
                fsdp_wrapper(l1),
                fsdp_wrapper(l2),
                fsdp_wrapper(l3),
            )

        def forward(self, x):
            return self.ffn(x)


    def _verify_parity(self, losses, outputs, models):
        assert losses
        assert outputs
        assert models

        for (l, o) in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

        # Verify grads
        ref_model = models[0]
        ref_grads = [p.grad for p in ref_model.parameters()]
        for m in models[1:]:
            grads = [p.grad for p in m.parameters()]
            for ref_g, g in zip(ref_grads, grads):
                self.assertEqual(ref_g, g)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_checkpoint_fsdp_wrapping(self, cpu_offload):
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True, cpu_offload=cpu_offload
            )
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True, wrap_fsdp=True, cpu_offload=cpu_offload
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True, cpu_offload=cpu_offload
        )

        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        models = [
            ckpt_sequential_wrapped_fsdp,
            inner_ckpt,
            baseline
        ]

        for _ in range(2):
            losses = []
            outputs = []
            for m in models:
                out = m(inp)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)]
    )
    def test_basic_checkpoint_end_to_end(self, cpu_offload):
        seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
        # Runs FSDP with no checkpointing
        fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Runs checkpoint-wrapped FSDP
        checkpointed_fsdp = checkpoint_wrapper(FSDP(deepcopy(seq), cpu_offload=cpu_offload))
        # Runs FSDP-wrapped checkpointed module
        fsdp_wrapped_checkpoint = FSDP(checkpoint_wrapper(deepcopy(seq)), cpu_offload=cpu_offload)
        # Runs FSDP with manual calls to checkpoint.
        fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.

        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        models = [
            fsdp_only_seq,
            checkpointed_fsdp,
            fsdp_wrapped_checkpoint,
            fsdp_call_checkpoint
        ]

        for _ in range(6):
            losses = []
            outputs = []
            for m in models:
                if m == fsdp_call_checkpoint:
                    out = checkpoint(m, inp)
                else:
                    out = m(inp)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)
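Before the extended version of this class below (which adds activation offloading), the two composition orders the tests above exercise can be sketched standalone: checkpoint_wrapper around FSDP vs. FSDP around a checkpointed module. Assumptions: single GPU, single-process NCCL group, and that checkpoint_wrapper is importable from torch.distributed.algorithms._checkpoint.checkpoint_wrapper (the imports are not shown above).

import os
from copy import deepcopy

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    checkpoint_wrapper,
)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29504")  # arbitrary free port (assumption)
dist.init_process_group(backend="nccl", rank=0, world_size=1)

torch.manual_seed(0)
seq = nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda()

# checkpoint(FSDP(...)) vs. FSDP(checkpoint(...)), as in test_basic_checkpoint_end_to_end.
ckpt_outside = checkpoint_wrapper(FSDP(deepcopy(seq)))
ckpt_inside = FSDP(checkpoint_wrapper(deepcopy(seq)))

# Reentrant checkpointing requires the input to require grad.
inp = torch.randn(10, 3, device="cuda", requires_grad=True)
for model in (ckpt_outside, ckpt_inside):
    model(inp).sum().backward()

dist.destroy_process_group()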
class TestFSDPCheckpoint(FSDPTest):
    class SequentialModule(nn.Module):
        def __init__(
            self,
            checkpoint_layer=False,
            offload_activations=False,
            wrap_fsdp=False,
            *fsdp_args,
            **fsdp_kwargs,
        ):
            torch.manual_seed(0)
            torch.cuda.manual_seed(0)
            super().__init__()
            l1 = nn.Linear(3, 3).cuda()
            l2 = nn.Linear(3, 3).cuda()
            l3 = nn.Linear(3, 3).cuda()

            if checkpoint_layer:
                ckpt_wrapper = partial(
                    checkpoint_wrapper, offload_to_cpu=offload_activations
                )

                l1 = ckpt_wrapper(l1)
                l2 = ckpt_wrapper(l2)
                l3 = ckpt_wrapper(l3)

            fsdp_wrapper = partial(
                _maybe_wrap_fsdp, wrap_fsdp=wrap_fsdp, *fsdp_args, **fsdp_kwargs
            )
            self.ffn = nn.Sequential(
                fsdp_wrapper(l1),
                fsdp_wrapper(l2),
                fsdp_wrapper(l3),
            )

        def forward(self, x):
            return self.ffn(x)

    def _verify_parity(self, losses, outputs, models):
        assert losses
        assert outputs
        assert models

        for (l, o) in zip(losses[1:], outputs[1:]):
            self.assertEqual(losses[0], l)
            self.assertEqual(outputs[0], o)

        # Verify grads
        ref_model = models[0]
        ref_grads = [p.grad for p in ref_model.parameters()]
        for m in models[1:]:
            grads = [p.grad for p in m.parameters()]
            for ref_g, g in zip(ref_grads, grads):
                self.assertEqual(ref_g, g)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True, cpu_offload=cpu_offload
            ),
            offload_to_cpu=offload_activations,
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            cpu_offload=cpu_offload,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True, cpu_offload=cpu_offload
        )

        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]

        offload_to_cpu_event = "Memcpy DtoH"

        for i in range(2):
            losses = []
            outputs = []
            for m in models:
                check_offload = m != baseline and i == 0 and offload_activations
                profiler_ctx = (
                    torch.profiler.profile(use_cuda=True)
                    if check_offload
                    else contextlib.suppress()
                )
                with profiler_ctx as prof:
                    out = m(inp)

                if check_offload:
                    event_names = [event.name for event in prof.events()]
                    offload_occurred = any(
                        offload_to_cpu_event in name for name in event_names
                    )
                    self.assertTrue(offload_occurred)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)

    @skip_if_lt_x_gpu(2)
    @parametrize(
        "cpu_offload",
        [CPUOffload(offload_params=True), CPUOffload(offload_params=False)],
    )
    @parametrize("offload_activations", [True, False])
    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
        seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
        # Runs FSDP with no checkpointing
        fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Runs checkpoint-wrapped FSDP
        checkpointed_fsdp = checkpoint_wrapper(
            FSDP(deepcopy(seq), cpu_offload=cpu_offload),
            offload_to_cpu=offload_activations,
        )
        # Runs FSDP-wrapped checkpointed module
        fsdp_wrapped_checkpoint = FSDP(
            checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
            cpu_offload=cpu_offload,
        )
        # Runs FSDP with manual calls to checkpoint.
        fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.

        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        models = [
            fsdp_only_seq,
            checkpointed_fsdp,
            fsdp_wrapped_checkpoint,
            fsdp_call_checkpoint,
        ]

        offload_to_cpu_event = "Memcpy DtoH"

        for i in range(6):
            losses = []
            outputs = []
            for m in models:
                check_offload = m != fsdp_only_seq and i == 0 and offload_activations
                profiler_ctx = (
                    torch.profiler.profile(use_cuda=True)
                    if check_offload
                    else contextlib.suppress()
                )
                with profiler_ctx as prof:
                    if m == fsdp_call_checkpoint:
                        offload_ctx = (
                            torch.autograd.graph.save_on_cpu(pin_memory=True)
                            if offload_activations
                            else contextlib.suppress()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        out = m(inp)

                if check_offload:
                    event_names = [event.name for event in prof.events()]
                    offload_occurred = any(
                        offload_to_cpu_event in name for name in event_names
                    )
                    self.assertTrue(offload_occurred)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)
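Finally, the manual-checkpoint branch above (fsdp_call_checkpoint) combines three public APIs: FSDP, torch.utils.checkpoint.checkpoint, and torch.autograd.graph.save_on_cpu for activation offloading. A minimal standalone sketch, under the same single-GPU, single-process assumptions as the earlier sketches:

import os

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.utils.checkpoint import checkpoint

os.environ.setdefault("MASTER_ADDR", "localhost")
os.environ.setdefault("MASTER_PORT", "29505")  # arbitrary free port (assumption)
dist.init_process_group(backend="nccl", rank=0, world_size=1)

torch.manual_seed(0)
model = FSDP(nn.Sequential(nn.Linear(3, 3), nn.Linear(3, 3), nn.Linear(3, 3)).cuda())

# Reentrant checkpointing requires the input to require grad.
inp = torch.randn(10, 3, device="cuda", requires_grad=True)

# Offload activations saved for backward to pinned CPU memory, mirroring the
# save_on_cpu context used around checkpoint(m, inp) above.
with torch.autograd.graph.save_on_cpu(pin_memory=True):
    out = checkpoint(model, inp)
out.sum().backward()

dist.destroy_process_group()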