Exemplo n.º 1
0
    def test_basic_checkpoint_end_to_end(self, cpu_offload, offload_activations):
        """
        Check that activation checkpointing composes with FSDP.

        Four copies of ``SequentialModule`` are run side by side for several
        iterations: plain FSDP, ``checkpoint_wrapper`` around FSDP, FSDP
        around ``checkpoint_wrapper``, and FSDP invoked through a manual
        ``checkpoint`` call. ``_verify_parity`` compares their losses and
        outputs. When ``offload_activations`` is set, the patched
        ``save_on_cpu`` must have been called by every checkpointed variant's
        forward pass on the first iteration.
        """
        global _save_on_cpu_called
        with patch_save_on_cpu(get_patched_save_on_cpu()):
            seq = TestFSDPCheckpoint.SequentialModule().to(torch.cuda.current_device())
            # Runs FSDP with no checkpointing
            fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
            # Runs checkpoint-wrapped FSDP
            checkpointed_fsdp = checkpoint_wrapper(
                FSDP(deepcopy(seq), cpu_offload=cpu_offload),
                offload_to_cpu=offload_activations,
            )
            # Runs FSDP-wrapped checkpointed module
            fsdp_wrapped_checkpoint = FSDP(
                checkpoint_wrapper(deepcopy(seq), offload_to_cpu=offload_activations),
                cpu_offload=cpu_offload,
            )
            # Runs FSDP with manual calls to checkpoint.
            fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
            # Reentrant-based checkpointing requires inputs to have the grad
            # flag set.
            inp = torch.randn(10, 3, device=torch.cuda.current_device(), requires_grad=True)

            models = [
                fsdp_only_seq,
                checkpointed_fsdp,
                fsdp_wrapped_checkpoint,
                fsdp_call_checkpoint,
            ]
            # Ensure _save_on_cpu is not yet called
            self.assertFalse(_save_on_cpu_called)
            for i in range(6):
                losses = []
                outputs = []
                for m in models:
                    # Compare module identity with `is`. Offloading is only
                    # asserted on the first iteration, and only for the
                    # variants that use checkpointing.
                    check_offload = m is not fsdp_only_seq and i == 0 and offload_activations
                    # The flag is reset after every model (end of this loop
                    # body), so nothing may have set it before this forward.
                    self.assertFalse(_save_on_cpu_called)
                    if m is fsdp_call_checkpoint:
                        offload_ctx = (
                            get_patched_save_on_cpu()(pin_memory=True)
                            if offload_activations
                            else contextlib.nullcontext()
                        )
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        out = m(inp)

                    if check_offload:
                        self.assertTrue(_save_on_cpu_called)
                    loss = out.sum()
                    loss.backward()
                    losses.append(loss)
                    outputs.append(out)
                    _save_on_cpu_called = False

                self._verify_parity(losses, outputs, models)
Exemplo n.º 2
0
 def test_auto_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
     """
     Check size-based auto wrapping when some submodules are passed as
     ``ignored_modules``, via both the FSDP constructor and the ``wrap`` API.
     """
     sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
     ignored_modules = [sequential[1], sequential[2][0]]
     my_auto_wrap_policy = functools.partial(
         size_based_auto_wrap_policy,
         min_num_params=40,
     )
     fsdp_kwargs = {
         "process_group": self.process_group,
         "auto_wrap_policy": my_auto_wrap_policy,
         "ignored_modules": ignored_modules,
     }
     if wrap_method == WrapMethod.FSDP_CTOR:
         model = FSDP(sequential, **fsdp_kwargs)
     elif wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
             model = wrap(sequential)
     else:
         # `self.fail` (unlike a bare `assert`) is not stripped under `-O`.
         self.fail(f"Unsupported wrap method: {wrap_method}")
     # Since the 2nd linear (`sequential[1]`) is ignored, the wrapping
     # policy does not exceed the parameter threshold before the inner
     # sequential (`sequential[2]`) anymore; hence, it flattens
     # `sequential[0]` and `sequential[2][1]` into `model` and leaves
     # `sequential[1]` and `sequential[2][0]` as-is since they are ignored
     self.assertIsInstance(model, FSDP)
     self.assertIsInstance(model.module[0], nn.Linear)
     self.assertIsInstance(model.module[1], nn.Linear)
     self.assertIsInstance(model.module[2], nn.Sequential)
     self.assertIsInstance(model.module[2][0], nn.Linear)
     self.assertIsInstance(model.module[2][1], nn.Linear)
Exemplo n.º 3
0
 def _run_fsdp_one_iteration(self, norm_type, nested_fsdp, cpu_offload):
     """Test FSDP with clip grad norm.

     Runs one forward/backward pass on an FSDP-wrapped ``DeterministicModel``
     (fed this rank's input row) and on a local ``DeterministicModel`` (fed
     the first ``world_size`` rows). It checks that the local total grad
     norm divided by ``world_size`` equals the FSDP total grad norm. It then
     clips both models to half the FSDP norm and checks that the clipped
     norms agree and respect the cap.
     """
     fsdp_model = DeterministicModel(nested_fsdp, cpu_offload=cpu_offload)
     local_model = DeterministicModel(False)
     inp = torch.rand(14, 2, device=self.rank)
     fsdp_model = FSDP(fsdp_model, cpu_offload=cpu_offload)
     # Every rank indexes its own row below, so enough rows must exist.
     self.assertGreaterEqual(len(inp), self.world_size)
     out = local_model(inp[:self.world_size])
     out.sum().backward()
     # Copy this rank's row explicitly; `torch.tensor(tensor)` is the
     # discouraged copy-construct path and emits a UserWarning. `inp` already
     # lives on device `self.rank`, so no device move is needed.
     in_data = inp[self.rank].detach().clone()
     out_fsdp = fsdp_model(in_data)
     out_fsdp.sum().backward()
     total_norms_fsdp = _collect_total_grad_norm_fsdp(
         fsdp_model, norm_type, self.rank)
     total_norms_local = _collect_total_grad_norm_local(
         local_model, norm_type)
     total_norms_local /= self.world_size
     norm_cap = total_norms_fsdp / 2.0
     self.assertEqual(total_norms_local, total_norms_fsdp)
     fsdp_model.clip_grad_norm_(norm_cap, norm_type=norm_type)
     nn_utils.clip_grad_norm_(local_model.parameters(),
                              norm_cap,
                              norm_type=norm_type)
     total_norms_after_clip_fsdp = _collect_total_grad_norm_fsdp(
         fsdp_model, norm_type, self.rank)
     total_norms_after_clip_local = _collect_total_grad_norm_local(
         local_model, norm_type)
     self.assertTrue(total_norms_after_clip_fsdp <= norm_cap)
     self.assertEqual(total_norms_after_clip_local,
                      total_norms_after_clip_fsdp)
Exemplo n.º 4
0
 def test_fsdp_calc_grad_norm(self, norm_type, nested_fsdp):
     """Test grad norm cal API.

     After one backward pass, ``_calc_grad_norm`` over
     ``params_with_grad`` must match ``_collect_total_grad_norm_local``
     evaluated on the same FSDP model.
     """
     fsdp_model = FSDP(DeterministicModel(nested_fsdp))
     batch = torch.rand(15, 2, device=self.rank)
     fsdp_model(batch).sum().backward()
     computed_norm = _calc_grad_norm(fsdp_model.params_with_grad, norm_type)
     reference_norm = _collect_total_grad_norm_local(fsdp_model, norm_type)
     self.assertEqual(computed_norm, reference_norm)
Exemplo n.º 5
0
    def test_bn_always_wrapped_individually(self):
        """
        Ensures that by using _or_policy with _wrap_batchnorm_individually, even
        if the other policy results in a module containing a BN unit being
        wrapped, the contained BN unit will still be individually wrapped.
        """
        class MyModule(nn.Module):
            def __init__(self):
                super().__init__()
                self.bn_container = BatchNormNet()

        def wrap_bn_container(module, recurse, *args, **kwargs):
            # Always recurse into children; wrap only BatchNormNet modules.
            if recurse:
                return True
            return isinstance(module, BatchNormNet)

        def bn_layers(fsdp_model):
            # The batchnorm submodules of the container, reached through the
            # FSDP-wrapped model.
            container = fsdp_model.bn_container
            return [container.bn1, container.bn2, container.bn3, container.sync_bn]

        my_policy = functools.partial(
            _or_policy,
            policies=[wrap_bn_container, _wrap_batchnorm_individually])
        fsdp = FSDP(MyModule(), auto_wrap_policy=my_policy)

        # Wrapping should be FSDP(FSDP(BatchNormNet(FSDP(BN))))
        # and not FSDP(FSDP(BatchNormNet(BN))) (in the latter the inner
        # BN is not individually wrapped.)
        for bn in bn_layers(fsdp):
            self.assertIsInstance(bn, FSDP)

        # If we only wrap the BN container, individual batchnorms are not
        # wrapped. Inspect through the FSDP handle, as in the check above.
        fsdp = FSDP(MyModule(), auto_wrap_policy=wrap_bn_container)
        self.assertIsInstance(fsdp.bn_container, FSDP)
        for bn in bn_layers(fsdp):
            self.assertNotIsInstance(bn, FSDP)
Exemplo n.º 6
0
 def __init__(self, nested):
     super().__init__()
     # TODO: test the various init modes.
     move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

     def make_linear():
         # A 1x1 linear from the enclosing test, moved to CUDA only when the
         # init mode is CUDA_BEFORE.
         return _maybe_cuda(fn_self._get_linear(1, 1), move_to_cuda)

     # With nested=True, the FSDP unit of lin1 sits one level deep inside an
     # nn.Sequential, and the enclosing test is expected to pick that up.
     if nested:
         leading_linear = make_linear()
         self.lin1 = nn.Sequential(leading_linear, FSDP(make_linear()))
     else:
         self.lin1 = FSDP(make_linear())
     self.lin2 = FSDP(make_linear())
     self.lin3 = FSDP(make_linear())
Exemplo n.º 7
0
 def test_always_wrap(self):
     """
     Passing ``always_wrap_policy`` to the FSDP constructor must wrap every
     submodule of the nested sequential test model.
     """
     nested_cls = TestFSDPWrap.NestedSequentialModel
     fsdp_model = FSDP(
         nested_cls.get_model(cuda=True),
         process_group=self.process_group,
         auto_wrap_policy=always_wrap_policy,
     )
     nested_cls.verify_model_all_wrapped(self, fsdp_model)
Exemplo n.º 8
0
    def test_error_already_wrapped(self, nested, fsdp_init_mode):
        """
        Test that an error is raised if we attempt to auto-wrap when
        submodules are already FSDP.
        """
        wrapped_fsdp = self._get_already_wrapped_fsdp(
            nested=nested, fsdp_init_mode=fsdp_init_mode)
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            wrapped_fsdp = wrapped_fsdp.cuda()

        with self.assertRaisesRegex(ValueError,
                                    "to NOT be FullyShardedDataParallel"):
            # Construction itself must raise; the result is never used.
            FSDP(wrapped_fsdp, auto_wrap_policy=default_auto_wrap_policy)
Exemplo n.º 9
0
    def test_auto_wrap_api(self):
        """
        Size-based auto wrapping with ``min_num_params=40`` must wrap child
        modules of the nested sequential model as ``verify_model`` expects: a
        single ``nn.Linear(5, 5)`` stays below the threshold, while the
        combined children exceed it.
        """
        nested_cls = TestFSDPWrap.NestedSequentialModel
        unwrapped = nested_cls.get_model(cuda=False)
        policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
        fsdp_model = FSDP(
            unwrapped,
            process_group=self.process_group,
            auto_wrap_policy=policy,
        )
        nested_cls.verify_model(self, fsdp_model)
Exemplo n.º 10
0
 def test_wrap(self, wrap_method):
     """
     Wrapping a lone ``nn.Linear(5, 5)`` via either the ``wrap`` API or the
     FSDP constructor must yield an FSDP instance whose rank and world size
     match ``self.process_group``.
     """
     pg = self.process_group
     if wrap_method != WrapMethod.WRAP_API:
         assert wrap_method == WrapMethod.FSDP_CTOR
         tiny_threshold_policy = functools.partial(
             size_based_auto_wrap_policy, min_num_params=1)
         layer = FSDP(nn.Linear(5, 5),
                      process_group=pg,
                      auto_wrap_policy=tiny_threshold_policy)
     else:
         with enable_wrap(wrapper_cls=FSDP, process_group=pg):
             layer = wrap(nn.Linear(5, 5))
     self.assertIsInstance(layer, FSDP)
     self.assertEqual(layer.rank, pg.rank())
     self.assertEqual(layer.world_size, pg.size())
Exemplo n.º 11
0
    def test_auto_wrap_smoke_test(self, cuda_init_mode, cpu_offload,
                                  use_device_id):
        """
        Smoke-test size-based auto wrapping in a single-rank NCCL process
        group: wrap the nested sequential model, verify the wrapping, then run
        one forward/backward pass.

        The process group is always destroyed and the rendezvous file is
        always removed, even if the test body raises.
        """
        # CPU offload and CUDA after don't work together as expected.
        if (cpu_offload.offload_params
                and cuda_init_mode == CUDAInitMode.CUDA_AFTER):
            return

        device = torch.device("cuda")
        torch.cuda.set_device(0)
        device_id = (torch.device("cuda", torch.cuda.current_device())
                     if use_device_id else None)

        # Use a random port in case the next test runs quickly; reusing the
        # same port would cause a conflict.
        os.environ["MASTER_ADDR"] = "localhost"
        os.environ["MASTER_PORT"] = str(find_free_port())

        # Only the path is needed for the file-based rendezvous, so close our
        # handle right away instead of leaking an open file object.
        with tempfile.NamedTemporaryFile(delete=False) as tmp_file:
            file_name = tmp_file.name
        try:
            torch.distributed.init_process_group(
                backend="nccl",
                init_method=f"{FILE_SCHEMA}_{file_name}",
                rank=0,
                world_size=1,
            )

            # NOTE: We move model to CUDA after init with FSDP to simulate real use
            # cases where full model cannot be loaded onto GPU, but their shards can.
            cuda_after_init = cuda_init_mode == CUDAInitMode.CUDA_AFTER
            try:
                sequential = TestFSDPWrap.NestedSequentialModel.get_model(
                    cuda=(not cuda_after_init))
                my_auto_wrap_policy = functools.partial(
                    size_based_auto_wrap_policy, min_num_params=40)
                model = FSDP(sequential,
                             cpu_offload=cpu_offload,
                             auto_wrap_policy=my_auto_wrap_policy,
                             device_id=device_id)
                TestFSDPWrap.NestedSequentialModel.verify_model(self, model)
                if cuda_after_init:
                    model = model.cuda()
                inp = torch.rand((1, 5), dtype=torch.float).to(device)
                output = model(inp)
                loss = F.mse_loss(inp, output)
                loss.backward()
            finally:
                torch.distributed.destroy_process_group()
        finally:
            # Always clean up the rendezvous file, even on failure.
            try:
                os.remove(file_name)
            except FileNotFoundError:
                pass
Exemplo n.º 12
0
    def test_auto_wrap_preset_exclude_wrap_include_children(self):
        """
        An excluded container (``nn.ModuleList``) is not wrapped itself, but
        a child whose parameter count exceeds ``min_num_params`` still gets
        its own FSDP unit.
        """
        module_list = nn.ModuleList([nn.Linear(10, 10)])
        policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
        fsdp_model = FSDP(
            module_list,
            process_group=self.process_group,
            auto_wrap_policy=policy,
        )

        self.assertIsInstance(fsdp_model, FSDP)
        self.assertIsInstance(fsdp_model[0], FSDP)
Exemplo n.º 13
0
 def test_transformer_auto_wrap_policy(self):
     """
     With ``transformer_auto_wrap_policy`` targeting the encoder and decoder
     layer classes, every encoder and decoder layer must become an FSDP unit.
     """
     model = TransformerWithSharedParams(group=self.process_group)
     layer_classes = {TransformerEncoderLayer, TransformerDecoderLayer}
     policy = functools.partial(transformer_auto_wrap_policy,
                                transformer_layer_cls=layer_classes)
     fsdp_model = FSDP(model,
                       process_group=self.process_group,
                       auto_wrap_policy=policy)
     self.assertIsInstance(fsdp_model, FSDP)
     transformer = fsdp_model.module.module.transformer
     for encoder_layer in transformer.encoder.layers:
         self.assertIsInstance(encoder_layer, FSDP)
     for decoder_layer in transformer.decoder.layers:
         self.assertIsInstance(decoder_layer, FSDP)
Exemplo n.º 14
0
    def test_wrap_batchnorm_individually(self, use_or_policy):
        """
        ``_wrap_batchnorm_individually``, used alone or OR-ed with a policy
        that never wraps, must wrap each batchnorm of ``BatchNormNet`` in its
        own FSDP unit and leave the linear layer unwrapped.
        """
        def never_wrap_policy(*args, **kwargs):
            return False

        if use_or_policy:
            policy = functools.partial(
                _or_policy,
                policies=[never_wrap_policy, _wrap_batchnorm_individually])
        else:
            policy = _wrap_batchnorm_individually
        fsdp = FSDP(BatchNormNet(), auto_wrap_policy=policy)

        # Each batchnorm gets its own FSDP unit...
        for bn_layer in (fsdp.bn1, fsdp.bn2, fsdp.bn3, fsdp.sync_bn):
            self.assertIsInstance(bn_layer, FSDP)
        # ...while the linear layer is left alone.
        self.assertNotIsInstance(fsdp.lin, FSDP)
Exemplo n.º 15
0
 def test_auto_wrap_preset_force_leaf(self):
     """
     Force-leaf modules and their children must not be wrapped.
     ``size_based_auto_wrap_policy`` treats ``nn.MultiheadAttention`` as a
     force-leaf type, so it stays unwrapped while the sibling linear is
     wrapped.
     """
     sequential = nn.Sequential(nn.Linear(10, 10), nn.MultiheadAttention(100, 1))
     policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)
     fsdp_model = FSDP(sequential,
                       process_group=self.process_group,
                       auto_wrap_policy=policy)
     inner = fsdp_model.module
     self.assertIsInstance(inner[0], FSDP)
     # Neither the multihead attention module nor its children are wrapped.
     attention = inner[1]
     self.assertIsInstance(attention, nn.MultiheadAttention)
     self.assertIsInstance(attention.out_proj, nn.Linear)
Exemplo n.º 16
0
    def test_auto_wrap_preset_exclude_wrap(self):
        """
        Excluded container types are never wrapped, even when their total
        parameter count exceeds ``min_num_params``. ``size_based_auto_wrap_policy``
        excludes ``nn.ModuleList`` and ``nn.ModuleDict``. Each child
        ``nn.Linear(5, 5)`` is individually below the threshold, so it stays
        unwrapped as well.
        """
        module_list = nn.ModuleList([nn.Linear(5, 5), nn.Linear(5, 5)])
        policy = functools.partial(size_based_auto_wrap_policy, min_num_params=40)

        fsdp_model = FSDP(module_list,
                          process_group=self.process_group,
                          auto_wrap_policy=policy)

        self.assertIsInstance(fsdp_model, FSDP)
        for child in (fsdp_model[0], fsdp_model[1]):
            self.assertIsInstance(child, nn.Linear)
Exemplo n.º 17
0
 def test_auto_wrap_preset_force_leaf_custom(self):
     """
     Test to ensure force-leaf modules are not wrapped. Adding ``nn.Linear``
     to the default force-leaf set keeps every linear layer unwrapped,
     including the one nested in an ``nn.ModuleList``.
     """
     default_leaves = size_based_auto_wrap_policy.FORCE_LEAF_MODULES
     custom_leaves = default_leaves.union({nn.Linear})
     policy = functools.partial(
         size_based_auto_wrap_policy,
         min_num_params=40,
         force_leaf_modules=custom_leaves,
     )
     sequential = nn.Sequential(
         nn.Linear(10, 10),
         nn.ModuleList([nn.Linear(10, 10)]),
     )
     fsdp_model = FSDP(sequential,
                       process_group=self.process_group,
                       auto_wrap_policy=policy)
     # Only the root is FSDP-wrapped, because no inner module qualified.
     self.assertIsInstance(fsdp_model, FSDP)
     self.assertIsInstance(fsdp_model.module[0], nn.Linear)
     self.assertIsInstance(fsdp_model.module[1], nn.ModuleList)
Exemplo n.º 18
0
 def test_always_wrap_with_ignored_modules(self, wrap_method: WrapMethod):
     """
     With ``always_wrap_policy`` and ``ignored_modules``, every non-ignored
     submodule is wrapped and the ignored ones are left as plain modules.
     Both the FSDP constructor and the ``wrap`` API are exercised.
     """
     sequential = TestFSDPWrap.NestedSequentialModel.get_model(cuda=False)
     ignored_modules = [sequential[1], sequential[2][0]]
     fsdp_kwargs = {
         "process_group": self.process_group,
         "auto_wrap_policy": always_wrap_policy,
         "ignored_modules": ignored_modules,
     }
     if wrap_method == WrapMethod.FSDP_CTOR:
         model = FSDP(sequential, **fsdp_kwargs)
     elif wrap_method == WrapMethod.WRAP_API:
         with enable_wrap(wrapper_cls=FSDP, **fsdp_kwargs):
             model = wrap(sequential)
     else:
         # `self.fail` (unlike a bare `assert`) is not stripped under `-O`.
         self.fail(f"Unsupported wrap method: {wrap_method}")
     # All non-ignored modules should be wrapped with FSDP
     self.assertIsInstance(model, FSDP)
     self.assertIsInstance(model.module[0], FSDP)
     self.assertIsInstance(model.module[1], nn.Linear)
     self.assertIsInstance(model.module[2], FSDP)
     self.assertIsInstance(model.module[2].module[0], nn.Linear)
     self.assertIsInstance(model.module[2].module[1], FSDP)
Exemplo n.º 19
0
    def test_basic_checkpoint_end_to_end(self, cpu_offload,
                                         offload_activations):
        """
        Check that activation checkpointing composes with FSDP.

        Four copies of ``SequentialModule`` are run side by side for several
        iterations: plain FSDP, ``checkpoint_wrapper`` around FSDP, FSDP
        around ``checkpoint_wrapper``, and FSDP invoked through a manual
        ``checkpoint`` call. ``_verify_parity`` compares their losses and
        outputs. When ``offload_activations`` is set, the first iteration of
        each checkpointed variant is profiled, and a device-to-host copy
        event must appear.
        """
        seq = TestFSDPCheckpoint.SequentialModule().to(
            torch.cuda.current_device())
        # Runs FSDP with no checkpointing
        fsdp_only_seq = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Runs checkpoint-wrapped FSDP
        checkpointed_fsdp = checkpoint_wrapper(
            FSDP(deepcopy(seq), cpu_offload=cpu_offload),
            offload_to_cpu=offload_activations,
        )
        # Runs FSDP-wrapped checkpointed module
        fsdp_wrapped_checkpoint = FSDP(
            checkpoint_wrapper(deepcopy(seq),
                               offload_to_cpu=offload_activations),
            cpu_offload=cpu_offload,
        )
        # Runs FSDP with manual calls to checkpoint.
        fsdp_call_checkpoint = FSDP(deepcopy(seq), cpu_offload=cpu_offload)
        # Reentrant-based checkpointing requires inputs to have the grad
        # flag set.
        inp = torch.randn(10,
                          3,
                          device=torch.cuda.current_device(),
                          requires_grad=True)

        models = [
            fsdp_only_seq,
            checkpointed_fsdp,
            fsdp_wrapped_checkpoint,
            fsdp_call_checkpoint,
        ]

        # Profiler event name for a device-to-host copy: "Memcpy DtoH" when
        # torch.version.cuda is set, "CopyDeviceToHost" otherwise.
        offload_to_cpu_event = "Memcpy DtoH" if torch.version.cuda else "CopyDeviceToHost"

        for i in range(6):
            losses = []
            outputs = []
            for m in models:
                # Compare module identity with `is`. Only the first iteration
                # of each checkpointed variant is profiled.
                check_offload = m is not fsdp_only_seq and i == 0 and offload_activations
                profiler_ctx = (torch.profiler.profile(use_cuda=True)
                                if check_offload else contextlib.nullcontext())
                # `prof` is None under nullcontext; it is only read when
                # check_offload is true.
                with profiler_ctx as prof:
                    if m is fsdp_call_checkpoint:
                        offload_ctx = (torch.autograd.graph.save_on_cpu(
                            pin_memory=True) if offload_activations else
                                       contextlib.nullcontext())
                        with offload_ctx:
                            out = checkpoint(m, inp)
                    else:
                        out = m(inp)

                if check_offload:
                    offload_occurred = any(offload_to_cpu_event in event.name
                                           for event in prof.events())
                    self.assertTrue(offload_occurred)
                loss = out.sum()
                loss.backward()
                losses.append(loss)
                outputs.append(out)

            self._verify_parity(losses, outputs, models)
Exemplo n.º 20
0
    def test_main_wrap_api(self, cpu_offload, backward_prefetch,
                           forward_prefetch, cuda_init_mode):
        """
        Auto-wrap a small nested model with ``min_num_params=0`` and check
        the result:

        - every submodule becomes an FSDP unit carrying the requested
          CPU-offload and prefetch settings;
        - a few SGD steps run;
        - the recorded FSDP graph order matches the expected module order.
        """
        # CUDA_AFTER combined with parameter CPU offload is not supported;
        # skip that combination by returning early.
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER and cpu_offload.offload_params:
            return

        move_to_cuda = cuda_init_mode == CUDAInitMode.CUDA_BEFORE

        def bias_free_linear():
            # 1x1 linear without bias, on CUDA only for CUDA_BEFORE.
            return _maybe_cuda(nn.Linear(1, 1, bias=False), move_to_cuda)

        class Nested(nn.Module):
            def __init__(self):
                super().__init__()
                self.nested_lin = bias_free_linear()

            def forward(self, x):
                return self.nested_lin(x)

        class MyModel(nn.Module):
            def __init__(self):
                super().__init__()
                self.lin1 = bias_free_linear()
                self.lin2 = bias_free_linear()
                self.lin3 = bias_free_linear()
                self.lin4 = Nested()

            def forward(self, x):
                hidden = x
                for layer in (self.lin1, self.lin2, self.lin3, self.lin4):
                    hidden = layer(hidden)
                return hidden

        wrap_all_policy = functools.partial(
            size_based_auto_wrap_policy,
            min_num_params=0,  # wrap all modules
        )
        wrapped_model = FSDP(
            MyModel(),
            auto_wrap_policy=wrap_all_policy,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            forward_prefetch=forward_prefetch,
        )
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            wrapped_model = wrapped_model.cuda()

        root = wrapped_model.module
        modules_in_fsdp_graph_order = [
            root.lin1,
            root.lin2,
            root.lin3,
            root.lin4.module.nested_lin,
            root.lin4,
            wrapped_model,
        ]

        for fsdp_unit in modules_in_fsdp_graph_order:
            self.assertTrue(isinstance(fsdp_unit, FSDP))
            self._check_cpu_offload(fsdp_unit, cpu_offload)
            self._check_backward_prefetch(fsdp_unit, backward_prefetch)
            self._check_forward_prefetch(fsdp_unit, forward_prefetch)

        # Run model a few times for sanity check.
        optim = torch.optim.SGD(wrapped_model.parameters(),
                                lr=1e-2,
                                momentum=0.9)
        inp = torch.ones(1).cuda()
        for _ in range(6):
            optim.zero_grad()
            wrapped_model(inp).sum().backward()
            optim.step()

        # Verify the graph-order bookkeeping recorded during the iterations
        # above against the expected order.
        for expected_idx, fsdp_unit in enumerate(modules_in_fsdp_graph_order):
            self.assertEqual(expected_idx, fsdp_unit._my_fsdp_idx_in_graph)
            self.assertTrue(
                fsdp_unit._fsdp_graph_order == modules_in_fsdp_graph_order)