Example #1
0
 def test_auto_wrap_preset_force_leaf_custom(self, wrap_method):
     """
     Test to ensure force-leaf modules are not wrapped.

     ``nn.Linear`` is added to the policy's force-leaf set. The test then
     checks that the outer model is FSDP while the top-level ``nn.Linear``
     and the ``nn.ModuleList`` child keep their original types.

     Args:
         wrap_method: ``WrapMethod.WRAP_API`` wraps via ``enable_wrap`` +
             ``auto_wrap``; ``WrapMethod.FSDP_CTOR`` passes the policy to the
             FSDP constructor.
     """
     my_auto_wrap_policy = functools.partial(
         default_auto_wrap_policy,
         min_num_params=40,
         force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
             {nn.Linear}
         ),
     )
     sequential = nn.Sequential(nn.Linear(10, 10),
                                nn.ModuleList([nn.Linear(10, 10)]))
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(
                 auto_wrap_policy=my_auto_wrap_policy,
                 wrapper_cls=FSDP,
                 process_group=self.process_group,
         ):
             model = auto_wrap(sequential)
     else:
         assert wrap_method == WrapMethod.FSDP_CTOR
         model = FSDP(sequential,
                      process_group=self.process_group,
                      fsdp_auto_wrap_policy=my_auto_wrap_policy)
     # Model was wrapped in FSDP as no inner modules were wrapped.
     self.assertIsInstance(model, FSDP)
     self.assertIsInstance(model.module[0], nn.Linear)
     self.assertIsInstance(model.module[1], nn.ModuleList)
Example #2
0
    def test_auto_wrap_preset_exclude_wrap(self, wrap_method):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater than
        min_num_params. The default_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.

        Args:
            wrap_method: ``WrapMethod.WRAP_API`` wraps via ``enable_wrap`` +
                ``auto_wrap``; ``WrapMethod.FSDP_CTOR`` passes the policy to the
                FSDP constructor.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)

        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            model = FSDP(sequential,
                         process_group=self.process_group,
                         fsdp_auto_wrap_policy=my_auto_wrap_policy)

        # Note that outermost module will be FSDP instance for FSDP_CTOR
        # approach, because we need to call the FSDP ctor so the returned obj
        # will be an FSDP instance. If we don't want to shard the outermost
        # module based on policy, we can apply the policy manually to the
        # outermost instance and skip the sharding.
        if wrap_method == WrapMethod.WRAP_API:
            self.assertIsInstance(model, nn.ModuleList)
        else:
            self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model[0], nn.Linear)
        self.assertIsInstance(model[1], nn.Linear)
Example #3
0
    def test_auto_wrap_smoke_test(self, fsdp_init_mode, cpu_offload):
        """
        Smoke test: auto-wrap a nested model with FSDP in a fresh single-rank
        NCCL process group and run one forward/backward pass on GPU 0.

        Args:
            fsdp_init_mode: with ``FSDPInitMode.CUDA_AFTER`` the model is built
                without CUDA and moved to CUDA after FSDP wrapping; otherwise it
                is built on CUDA up front.
            cpu_offload: CPU-offload config forwarded to FSDP. The combination
                of ``offload_params`` with ``CUDA_AFTER`` returns early.

        The process group and the ``MASTER_ADDR`` / ``MASTER_PORT`` environment
        variables set here are always cleaned up, including when
        ``init_process_group`` itself raises.
        """
        # CPU offload and CUDA after don't work together as expected.
        if (
            cpu_offload.offload_params and fsdp_init_mode == FSDPInitMode.CUDA_AFTER
        ):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)

        # Random port in case the next test runs quickly; reusing the same port
        # would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())
        # The outer try also covers init_process_group, so the environment
        # variables above are removed even if initialization fails.
        try:
            torch.distributed.init_process_group(backend="nccl", rank=0, world_size=1)
            # Only destroy the process group once it was actually created.
            try:
                # NOTE: We move model to CUDA after init with FSDP to simulate
                # real use cases where full model cannot be loaded onto GPU,
                # but their shards can.
                cuda_after_init = fsdp_init_mode == FSDPInitMode.CUDA_AFTER
                sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=(not cuda_after_init))
                my_auto_wrap_policy = functools.partial(
                    default_auto_wrap_policy, min_num_params=40
                )
                model = FSDP(sequential, cpu_offload=cpu_offload, fsdp_auto_wrap_policy=my_auto_wrap_policy)
                TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
                if cuda_after_init:
                    model = model.cuda()
                inp = torch.rand((1, 5), dtype=torch.float).to(device)
                output = model(inp)
                loss = F.mse_loss(inp, output)
                loss.backward()
            finally:
                torch.distributed.destroy_process_group()
        finally:
            os.environ.pop("MASTER_ADDR", None)
            os.environ.pop("MASTER_PORT", None)
Example #4
0
    def test_main_wrap_api(self, cpu_offload, fsdp_init_mode):
        """
        Test that an auto-wrap policy with ``min_num_params=0`` wraps the root
        model and every submodule, including a linear layer nested inside a
        child module, and that the wrapped model can run a few SGD steps.

        Every expected FSDP instance is also passed to ``_check_cpu_offload``
        together with ``cpu_offload``.

        Args:
            cpu_offload: CPU-offload config forwarded to FSDP.
            fsdp_init_mode: ``FSDPInitMode.CUDA_BEFORE`` moves layers to CUDA
                before wrapping; ``FSDPInitMode.CUDA_AFTER`` moves the wrapped
                model afterwards. ``CUDA_AFTER`` with ``offload_params`` returns
                early.
        """
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            # they don't work together, expected
            return

        move_to_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

            def forward(self, input):
                return self.nested_lin(input)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin2 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin3 = _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)
                self.lin4 = Nested()

            def forward(self, input):
                return self.lin4(self.lin3(self.lin2(self.lin1(input))))

        model = MyModel()
        wrapped_model = FSDP(
            model,
            fsdp_auto_wrap_policy=functools.partial(
                default_auto_wrap_policy,
                min_num_params=0,  # wrap all modules
            ),
            cpu_offload=cpu_offload,
        )
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        modules = [
            wrapped_model,
            wrapped_model.module.lin1,
            wrapped_model.module.lin2,
            wrapped_model.module.lin3,
            wrapped_model.module.lin4,
            # Nested FSDP
            wrapped_model.module.lin4.module.nested_lin,
        ]
        for module in modules:
            self.assertIsInstance(module, FSDP)
            self._check_cpu_offload(module, cpu_offload)
        # Run model a few times for sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(), lr=1e-2, momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            loss = wrapped_model(inp).sum()
            loss.backward()
            optim.step()
Example #5
0
 def __init__(self, nested):
     """
     Build a model whose ``lin1``, ``lin2`` and ``lin3`` submodules are
     already wrapped in FSDP.

     Reads ``fsdp_init_mode`` and ``fn_self`` from the enclosing scope.

     Args:
         nested: if True, ``lin1`` is an ``nn.Sequential`` whose second child
             is FSDP-wrapped, so the FSDP module sits one layer deep. If False,
             ``lin1`` itself is FSDP-wrapped.
     """
     super().__init__()
     # TODO: test the various init modes.
     on_cuda = fsdp_init_mode == FSDPInitMode.CUDA_BEFORE

     def new_linear():
         # Fresh 1x1 linear layer, moved to CUDA when on_cuda is set.
         return _maybe_cuda(fn_self._get_linear(1, 1), on_cuda)

     if not nested:
         self.lin1 = FSDP(new_linear())
     else:
         # FSDP module nested one layer deep; the wrap check should pick it up.
         self.lin1 = nn.Sequential(new_linear(), FSDP(new_linear()))
     self.lin2 = FSDP(new_linear())
     self.lin3 = FSDP(new_linear())
Example #6
0
    def test_error_already_wrapped(self, nested, fsdp_init_mode):
        """
        Test that an error is raised if we attempt to wrap when submodules are
        already FSDP.

        Args:
            nested: forwarded to ``_get_already_wrapped_fsdp``.
            fsdp_init_mode: forwarded to ``_get_already_wrapped_fsdp``. With
                ``FSDPInitMode.CUDA_AFTER`` the model is moved to CUDA before
                the second wrap attempt.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(nested=nested, fsdp_init_mode=fsdp_init_mode)
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        with self.assertRaisesRegex(ValueError, "to NOT be FullyShardedDataParallel"):
            # Only the raised error matters; no result is kept.
            FSDP(wrapped_fsdp, fsdp_auto_wrap_policy=default_auto_wrap_policy)
Example #7
0
    def test_auto_wrap_api(self):
        """
        Check that the FSDP constructor's auto-wrap policy wraps child modules
        according to ``min_num_params``.

        A single ``nn.Linear(5, 5)`` stays under the 40-parameter threshold,
        but the combined children exceed it.
        """
        nested_cls = TestFSDPWrap.NestedSequentialModel
        unwrapped = nested_cls.get_model(cuda=False)
        policy = functools.partial(default_auto_wrap_policy, min_num_params=40)
        wrapped = FSDP(
            unwrapped,
            process_group=self.process_group,
            fsdp_auto_wrap_policy=policy,
        )

        nested_cls.verify_model(self, wrapped)
Example #8
0
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        Test to ensure excluded modules are not wrapped, but children are if param size is greater than
        min_num_params.

        The ``nn.ModuleList`` holds one ``nn.Linear(10, 10)``. The test checks
        that the returned model is FSDP and that its first child is FSDP too.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(
            default_auto_wrap_policy, min_num_params=40
        )
        model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)

        self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model[0], FSDP)
Example #9
0
 def test_wrap(self, wrap_method):
     """
     Test that wrapping a single ``nn.Linear(5, 5)`` yields an FSDP instance
     whose rank and world size match ``self.process_group``.

     Args:
         wrap_method: ``WrapMethod.WRAP_API`` wraps via ``wrap()`` inside
             ``enable_wrap``; ``WrapMethod.FSDP_CTOR`` calls the FSDP
             constructor with an auto-wrap policy of ``min_num_params=1``.
     """
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
             layer = wrap(nn.Linear(5, 5))
     else:
         assert wrap_method == WrapMethod.FSDP_CTOR
         layer = FSDP(
             nn.Linear(5, 5),
             process_group=self.process_group,
             fsdp_auto_wrap_policy=functools.partial(default_auto_wrap_policy, min_num_params=1)
         )
     self.assertIsInstance(layer, FSDP)
     self.assertEqual(layer.rank, self.process_group.rank())
     self.assertEqual(layer.world_size, self.process_group.size())
Example #10
0
 def test_auto_wrap_preset_force_leaf(self):
     """
     Test to ensure force-leaf modules are not wrapped, and their children are
     not wrapped either. The default_auto_wrap_policy forces modules of type
     {nn.MultiheadAttention} to be leaves that are not wrapped.
     """
     sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
     my_auto_wrap_policy = functools.partial(
         default_auto_wrap_policy, min_num_params=40
     )
     model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)
     self.assertIsInstance(model.module[0], FSDP)
     # Assert children of multihead attention are not wrapped
     self.assertIsInstance(model.module[1], nn.MultiheadAttention)
     self.assertIsInstance(model.module[1].out_proj, nn.Linear)
Example #11
0
    def test_basic_checkpoint_end_to_end(self, cpu_offload):
        """
        Check that activation checkpointing composes with FSDP.

        Four deep copies of the same module run for six iterations. Their
        losses and outputs are compared with ``_verify_parity``:

        1. FSDP without checkpointing,
        2. ``checkpoint_wrapper`` around an FSDP module,
        3. FSDP around a ``checkpoint_wrapper``-ed module,
        4. FSDP driven through explicit ``checkpoint(...)`` calls.

        Args:
            cpu_offload: CPU-offload config forwarded to every FSDP instance.
        """
        seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
        # Runs FSDP with no checkpointing
        fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Runs checkpoint-wrapped FSDP
        checkpointed_fsdp = checkpoint_wrapper(FSDP(deepcopy(seq), cpu_offload=cpu_offload))
        # Runs FSDP-wrapped checkpointed module
        fsdp_wrapped_checkpoint = FSDP(checkpoint_wrapper(deepcopy(seq)), cpu_offload=cpu_offload)
        # Runs FSDP with manual calls to checkpoint.
        fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.

        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        models = [
            fsdp_only_seq,
            checkpointed_fsdp,
            fsdp_wrapped_checkpoint,
            fsdp_call_checkpoint
        ]

        for _ in range(6):
            losses = []
            outputs = []
            for m in models:
                # Identity check: only this variant is driven via checkpoint().
                if m is fsdp_call_checkpoint:
                    out = checkpoint(m, inp)
                else:
                    out = m(inp)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)
Example #12
0
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Test to ensure excluded modules are not wrapped, regardless of whether the total param size is greater than
        min_num_params. The default_auto_wrap_policy excludes wrapping for {nn.ModuleList, nn.ModuleDict}.
        """
        sequential = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)

        model = FSDP(sequential,
                     process_group=self.process_group,
                     fsdp_auto_wrap_policy=my_auto_wrap_policy)

        self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model[0], nn.Linear)
        self.assertIsInstance(model[1], nn.Linear)
Example #13
0
 def test_auto_wrap_preset_force_leaf_custom(self):
     """
     Test to ensure force-leaf modules are not wrapped.

     ``nn.Linear`` is added to the policy's force-leaf set. The test then
     checks that the outer model is FSDP while both children keep their
     original types.
     """
     my_auto_wrap_policy = functools.partial(
         default_auto_wrap_policy,
         min_num_params=40,
         force_leaf_modules=default_auto_wrap_policy.FORCE_LEAF_MODULES.union(
             {nn.Linear}
         ),
     )
     sequential = nn.Sequential(
         nn.Linear(10, 10), nn.ModuleList([nn.Linear(10, 10)])
     )
     model = FSDP(sequential, process_group=self.process_group, fsdp_auto_wrap_policy=my_auto_wrap_policy)
     # Model was wrapped in FSDP as no inner modules were wrapped.
     self.assertIsInstance(model, FSDP)
     self.assertIsInstance(model.module[0], nn.Linear)
     self.assertIsInstance(model.module[1], nn.ModuleList)
Example #14
0
    def test_auto_wrap_foo(self, wrap_method):
        """
        Check that auto-wrapping wraps child modules according to
        ``min_num_params``, through either wrapping entry point.

        A single ``nn.Linear(5, 5)`` stays under the 40-parameter threshold,
        but the combined children exceed it.

        Args:
            wrap_method: ``WrapMethod.WRAP_API`` wraps via ``enable_wrap`` +
                ``auto_wrap``; ``WrapMethod.FSDP_CTOR`` passes the policy to the
                FSDP constructor.
        """
        nested_cls = TestFSDPWrap.NestedSequentialModel
        unwrapped = nested_cls.get_model(cuda=False)
        policy = functools.partial(default_auto_wrap_policy, min_num_params=40)

        if wrap_method != WrapMethod.WRAP_API:
            assert wrap_method == WrapMethod.FSDP_CTOR
            wrapped = FSDP(
                unwrapped,
                process_group=self.process_group,
                fsdp_auto_wrap_policy=policy,
            )
        else:
            with enable_wrap(wrapper_cls=FSDP, process_group=self.process_group):
                wrapped = auto_wrap(unwrapped, auto_wrap_policy=policy)

        nested_cls.verify_model(self, wrapped)
Example #15
0
 def test_auto_wrap_preset_force_leaf(self, wrap_method):
     """
     Test to ensure force-leaf modules are not wrapped, and their children are
     not wrapped either. The default_auto_wrap_policy forces modules of type
     {nn.MultiheadAttention} to be leaves that are not wrapped.

     Args:
         wrap_method: ``WrapMethod.WRAP_API`` wraps via ``enable_wrap`` +
             ``auto_wrap``; ``WrapMethod.FSDP_CTOR`` passes the policy to the
             FSDP constructor.
     """
     sequential = nn.Sequential(nn.Linear(10, 10),
                                nn.MultiheadAttention(100, 1))
     my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                             min_num_params=40)
     if wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP,
                          process_group=self.process_group):
             model = auto_wrap(sequential,
                               auto_wrap_policy=my_auto_wrap_policy)
     else:
         assert wrap_method == WrapMethod.FSDP_CTOR
         model = FSDP(sequential,
                      process_group=self.process_group,
                      fsdp_auto_wrap_policy=my_auto_wrap_policy)
     self.assertIsInstance(model.module[0], FSDP)
     # Assert children of multihead attention are not wrapped
     self.assertIsInstance(model.module[1], nn.MultiheadAttention)
     self.assertIsInstance(model.module[1].out_proj, nn.Linear)
Example #16
0
    def test_auto_wrap_preset_exclude_wrap_include_children(self, wrap_method):
        """
        Test to ensure excluded modules are not wrapped, but children are if param size is greater than
        min_num_params.

        Args:
            wrap_method: ``WrapMethod.WRAP_API`` wraps via ``enable_wrap`` +
                ``auto_wrap``, and the returned top level stays an
                ``nn.ModuleList``. ``WrapMethod.FSDP_CTOR`` goes through the
                FSDP constructor, so the top level is always FSDP.
        """
        sequential = nn.ModuleList([nn.Linear(10, 10)])
        my_auto_wrap_policy = functools.partial(default_auto_wrap_policy,
                                                min_num_params=40)
        if wrap_method == WrapMethod.WRAP_API:
            with enable_wrap(wrapper_cls=FSDP,
                             process_group=self.process_group):
                model = auto_wrap(sequential,
                                  auto_wrap_policy=my_auto_wrap_policy)
        else:
            assert wrap_method == WrapMethod.FSDP_CTOR
            model = FSDP(sequential,
                         process_group=self.process_group,
                         fsdp_auto_wrap_policy=my_auto_wrap_policy)

        if wrap_method == WrapMethod.WRAP_API:
            self.assertIsInstance(model, nn.ModuleList)
        else:
            self.assertIsInstance(model, FSDP)
        self.assertIsInstance(model[0], FSDP)
    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
        """
        Check that activation checkpointing, optionally with activation
        offloading to CPU, composes with FSDP.

        Four deep copies of the same module run for six iterations. Their
        losses and outputs are compared with ``_verify_parity``:

        1. FSDP without checkpointing,
        2. ``checkpoint_wrapper`` around an FSDP module,
        3. FSDP around a ``checkpoint_wrapper``-ed module,
        4. FSDP driven through explicit ``checkpoint(...)`` calls, under
           ``save_on_cpu`` when offloading is requested.

        When ``offload_activations`` is set, each checkpointed variant's first
        forward pass is profiled. Its events must include a "Memcpy DtoH" entry.

        Args:
            cpu_offload: CPU-offload config forwarded to every FSDP instance.
            offload_activations: whether checkpointed activations are
                offloaded to CPU.
        """
        seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
        # Runs FSDP with no checkpointing
        fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Runs checkpoint-wrapped FSDP
        checkpointed_fsdp = checkpoint_wrapper(
            FSDP(deepcopy(seq), cpu_offload=cpu_offload),
            offload_to_cpu=offload_activations,
        )
        # Runs FSDP-wrapped checkpointed module
        fsdp_wrapped_checkpoint = FSDP(
            checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
            cpu_offload=cpu_offload,
        )
        # Runs FSDP with manual calls to checkpoint.
        fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # note that reentrant-based checkpointing requires inputs to have grad
        # flag set.

        inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

        models = [
            fsdp_only_seq,
            checkpointed_fsdp,
            fsdp_wrapped_checkpoint,
            fsdp_call_checkpoint,
        ]

        offload_to_cpu_event = "Memcpy DtoH"

        for i in range(6):
            losses = []
            outputs = []
            for m in models:
                # Only the checkpointed variants are configured to offload, and
                # the check is done on the first iteration only.
                check_offload = m is not fsdp_only_seq and i == 0 and offload_activations
                profiler_ctx = (
                    torch.profiler.profile(use_cuda=True)
                    if check_offload
                    else contextlib.nullcontext()
                )
                with profiler_ctx as prof:
                    if m is fsdp_call_checkpoint:
                        offload_ctx = (
                            torch.autograd.graph.save_on_cpu(pin_memory=True)
                            if offload_activations
                            else contextlib.nullcontext()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        out = m(inp)

                if check_offload:
                    event_names = [event.name for event in prof.events()]
                    offload_occurred = any(
                        offload_to_cpu_event in name for name in event_names
                    )
                    self.assertTrue(offload_occurred)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)