Example #1
    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
            # Runs FSDP with no checkpointing
            fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
            # Runs checkpoint-wrapped FSDP
            checkpointed_fsdp = checkpoint_wrapper(
                FSDP(deepcopy(seq), cpu_offload=cpu_offload),
                offload_to_cpu=offload_activations,
            )
            # Runs FSDP-wrapped checkpointed module
            fsdp_wrapped_checkpoint = FSDP(
                checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
                cpu_offload=cpu_offload,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
            # Note that reentrant-based checkpointing requires inputs to have the
            # requires_grad flag set.

            inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure _save_on_cpu is not yet called
            self.assertFalse(_save_on_cpu_called)
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != fsdp_only_seq and i == 0 and offload_activations
                    if m == fsdp_call_checkpoint:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.suppress()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        # _save_on_cpu should not be called yet
                        self.assertFalse(_save_on_cpu_called)
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)
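The helpers get_patched_save_on_cpu() and patch_save_on_cpu() used above are test-local utilities rather than public PyTorch API. A minimal sketch of what they might look like, assuming they simply record whenever torch.autograd.graph.save_on_cpu is entered, is:

import contextlib
import torch

_save_on_cpu_called = False

def get_patched_save_on_cpu():
    # Wrap the real save_on_cpu so the test can observe that activation
    # offloading was actually triggered.
    orig_save_on_cpu = torch.autograd.graph.save_on_cpu

    def patched_save_on_cpu(*args, **kwargs):
        global _save_on_cpu_called
        _save_on_cpu_called = True
        return orig_save_on_cpu(*args, **kwargs)

    return patched_save_on_cpu

@contextlib.contextmanager
def patch_save_on_cpu(new_save_on_cpu):
    # Swap in the patched version for the duration of the test body. The exact
    # attribute to patch depends on how checkpoint_wrapper imports save_on_cpu;
    # patching torch.autograd.graph directly is only one possibility.
    orig = torch.autograd.graph.save_on_cpu
    torch.autograd.graph.save_on_cpu = new_save_on_cpu
    try:
        yield
    finally:
        torch.autograd.graph.save_on_cpu = orig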
Example #2
    def test_load_activation_checkpointed_module(self):
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(
            lin,
            checkpoint_fn=checkpoint,
            # checkpoint kwargs
            use_reentrant=True,
            preserve_rng_state=False,
        )
        state_dict = deepcopy(lin.state_dict())
        # Load into non-checkpoint wrapped linear module
        lin_new = nn.Linear(10, 10, bias=False)
        lin_new.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)
            self.assertTrue(torch.allclose(p1, p2))

        # Load non-checkpoint wrapped module into checkpoint wrapped one
        # Make params different
        for p in lin_new.parameters():
            with torch.no_grad():
                p.add_(0.5)

        state_dict = deepcopy(lin_new.state_dict())
        # Verify checkpoint wrapped linear can load unwrapped linear
        lin.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)
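This cross-loading works because checkpoint_wrapper exposes the same state_dict keys as the module it wraps (the wrapper prefix is stripped). A quick standalone sanity check of that key layout, outside the test suite, might look like:

import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper

plain = nn.Linear(10, 10, bias=False)
wrapped = checkpoint_wrapper(nn.Linear(10, 10, bias=False))

# Both are expected to expose the single key "weight", which is why
# load_state_dict works in either direction in the test above.
print(sorted(plain.state_dict().keys()))
print(sorted(wrapped.state_dict().keys()))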
Example #3
    def _get_simple_model(self,
                          *fsdp_args,
                          checkpoint_wrap=False,
                          **fsdp_kwargs):
        lin = nn.Linear(10, 10, bias=False).cuda()
        if checkpoint_wrap:
            lin = checkpoint_wrapper(lin)
        model = FSDP(lin, *fsdp_args, **fsdp_kwargs)
        return model
Example #4
    def _get_simple_nested_model(self,
                                 *fsdp_args,
                                 wrap=True,
                                 checkpoint_wrap=False,
                                 **fsdp_kwargs):
        if wrap:
            lin1 = nn.Linear(10, 10, bias=False).cuda()
            lin2 = nn.Linear(10, 10, bias=False).cuda()
            if checkpoint_wrap:
                lin1 = checkpoint_wrapper(lin1)
                lin2 = checkpoint_wrapper(lin2)
            seq = nn.Sequential(FSDP(lin1, *fsdp_args, **fsdp_kwargs), lin2)
            if checkpoint_wrap:
                seq = checkpoint_wrapper(seq)
            model = FSDP(seq, *fsdp_args, **fsdp_kwargs)
        else:
            model = nn.Sequential(
                nn.Linear(10, 10, bias=False).cuda(),
                nn.Linear(10, 10, bias=False).cuda())
        return model
Example #5
    def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
            TestFSDPCheckpoint.SequentialModule(wrap_fsdp=True,
                                                cpu_offload=cpu_offload),
            offload_to_cpu=offload_activations,
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            cpu_offload=cpu_offload,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(wrap_fsdp=True,
                                                       cpu_offload=cpu_offload)

        # Note that reentrant-based checkpointing requires inputs to have the
        # requires_grad flag set.
        inp = torch.randn(10,
                          3,
                          device=torch.cuda.current_device(),
                          requires_grad=True)

        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]

        offload_to_cpu_event = "Memcpy DtoH" if torch.version.cuda else "CopyDeviceToHost"

        for i in range(2):
            losses = []
            outputs = []
            for m in models:
                check_offload = m != baseline and i == 0 and offload_activations
                profiler_ctx = (torch.profiler.profile(
                    use_cuda=True) if check_offload else contextlib.suppress())
                with profiler_ctx as prof:
                    out = m(inp)

                if check_offload:
                    event_names = [event.name for event in prof.events()]
                    offload_occurred = any(offload_to_cpu_event in name
                                           for name in event_names)
                    self.assertTrue(offload_occurred)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)
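TestFSDPCheckpoint.SequentialModule is not shown in these examples. Judging from how it is constructed (keyword arguments wrap_fsdp, cpu_offload, checkpoint_layer, offload_activations) and fed an input of shape (10, 3), a rough sketch could be the following; the layer count, widths, and wrapping order are assumptions:

import torch
import torch.nn as nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import checkpoint_wrapper
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class SequentialModule(nn.Module):
    # Hypothetical reconstruction: three small Linear layers whose input width
    # matches the test input torch.randn(10, 3).
    def __init__(self, checkpoint_layer=False, offload_activations=False,
                 wrap_fsdp=False, *fsdp_args, **fsdp_kwargs):
        super().__init__()
        layers = [nn.Linear(3, 3).cuda() for _ in range(3)]
        if checkpoint_layer:
            # Per-layer activation checkpointing, optionally offloading to CPU.
            layers = [
                checkpoint_wrapper(layer, offload_to_cpu=offload_activations)
                for layer in layers
            ]
        if wrap_fsdp:
            # Shard each layer individually; cpu_offload arrives via fsdp_kwargs.
            layers = [FSDP(layer, *fsdp_args, **fsdp_kwargs) for layer in layers]
        self.ffn = nn.Sequential(*layers)

    def forward(self, x):
        return self.ffn(x)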
Example #6
    def test_checkpoint_fsdp_wrapping(self, cpu_offload, offload_activations):
        # Test checkpoint(FSDP(layer1), FSDP(layer2), ....)
        ckpt_sequential_wrapped_fsdp = checkpoint_wrapper(
            TestFSDPCheckpoint.SequentialModule(
                wrap_fsdp=True, cpu_offload=cpu_offload
            ),
            offload_to_cpu=offload_activations,
        )
        # Test FSDP(checkpoint(layer1)), FSDP(checkpoint(layer2)), ....
        inner_ckpt = TestFSDPCheckpoint.SequentialModule(
            checkpoint_layer=True,
            offload_activations=offload_activations,
            wrap_fsdp=True,
            cpu_offload=cpu_offload,
        )

        baseline = TestFSDPCheckpoint.SequentialModule(
            wrap_fsdp=True, cpu_offload=cpu_offload
        )

        # Note that reentrant-based checkpointing requires inputs to have the
        # requires_grad flag set.
        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        global _save_on_cpu_called
        models = [ckpt_sequential_wrapped_fsdp, inner_ckpt, baseline]
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            for i in range(2):
                losses = []
                outputs = []
                for m in models:
                    check_offload = m != baseline and i == 0 and offload_activations
                    if check_offload:
                        self.assertFalse(_save_on_cpu_called)
                    out = m(inp)
                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                        _save_on_cpu_called = False
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)

                self._verify_parity(losses, outputs, models)
Example #7
    def test_load_activation_checkpointed_module(self):
        # TODO: move this tests to checkpoint_wrapper tests once there is a dedicated
        # test suite for them: https://github.com/pytorch/pytorch/issues/77478.
        lin = nn.Linear(10, 10, bias=False).cuda()
        lin = checkpoint_wrapper(lin)
        state_dict = deepcopy(lin.state_dict())
        # Load into non-checkpoint wrapped linear module
        lin_new = nn.Linear(10, 10, bias=False).cuda()
        lin_new.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)

        # Load non-checkpoint wrapped module into checkpoint wrapped one
        # Make params different
        for p in lin_new.parameters():
            with torch.no_grad():
                p.add_(0.5)

        state_dict = deepcopy(lin_new.state_dict())
        # Verify checkpoint wrapped linear can load unwrapped linear
        lin.load_state_dict(state_dict)
        for p1, p2 in zip(lin.parameters(), lin_new.parameters()):
            self.assertEqual(p1, p2)
Example #8
    def test_fqn(self):
        lin = nn.Linear(10, 10, bias=False)
        lin = checkpoint_wrapper(lin)
        state_dict = lin.state_dict()
        for fqn, _ in lin.named_parameters():
            self.assertTrue(fqn in state_dict, msg=f"{fqn} not in state_dict.")
Example #9
    def test_basic_checkpoint_end_to_end(self, cpu_offload,
                                         offload_activations):
        seq = TestFSDPCheckpoint.SequentialModule().to(
            torch.cuda.current_device())
        # Runs FSDP with no checkpointing
        fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Runs checkpoint-wrapped FSDP
        checkpointed_fsdp = checkpoint_wrapper(
            FSDP(deepcopy(seq), cpu_offload=cpu_offload),
            offload_to_cpu=offload_activations,
        )
        # Runs FSDP-wrapped checkpointed module
        fsdp_wrapped_checkpoint = FSDP(
            checkpoint_wrapper(deepcopy(seq),
                               offload_to_cpu=offload_activations),
            cpu_offload=cpu_offload,
        )
        # Runs FSDP with manual calls to checkpoint.
        fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Note that reentrant-based checkpointing requires inputs to have the
        # requires_grad flag set.

        inp = torch.randn(10,
                          3,
                          device=torch.cuda.current_device(),
                          requires_grad=True)

        models = [
            fsdp_only_seq,
            checkpointed_fsdp,
            fsdp_wrapped_checkpoint,
            fsdp_call_checkpoint,
        ]

        offload_to_cpu_event = "Memcpy DtoH" if torch.version.cuda else "CopyDeviceToHost"

        for i in range(6):
            losses = []
            outputs = []
            for m in models:
                check_offload = m != fsdp_only_seq and i == 0 and offload_activations
                profiler_ctx = (torch.profiler.profile(
                    use_cuda=True) if check_offload else contextlib.suppress())
                with profiler_ctx as prof:
                    if m == fsdp_call_checkpoint:
                        offload_ctx = (torch.autograd.graph.save_on_cpu(
                            pin_memory=True) if offload_activations else
                                       contextlib.suppress())
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        out = m(inp)

                if check_offload:
                    event_names = [event.name for event in prof.events()]
                    offload_occurred = any(offload_to_cpu_event in name
                                           for name in event_names)
                    self.assertTrue(offload_occurred)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)
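For reference, the device-to-host copies this profiler check looks for can also be observed without FSDP in the picture. A minimal standalone sketch (assuming a CUDA build, so the event name contains "Memcpy DtoH" as in the test above):

import torch
import torch.nn as nn

lin = nn.Linear(1024, 1024).cuda()
inp = torch.randn(64, 1024, device="cuda", requires_grad=True)

with torch.profiler.profile(
    activities=[torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA],
) as prof:
    # save_on_cpu moves saved activations to (pinned) host memory, which shows
    # up as device-to-host memcpy events in the trace.
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        out = lin(inp)
    out.sum().backward()

event_names = [event.name for event in prof.events()]
print(any("Memcpy DtoH" in name for name in event_names))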