def test_params_count_and_value(self, rank0_only, offload_to_cpu,
                                    mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                mixed_precision=mixed_precision,
            ),
            mixed_precision=mixed_precision,
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with fsdp_model.summon_full_params(fsdp_model,
                                           rank0_only=rank0_only,
                                           writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
    def test_summon_full_param_shard_value(self):
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank))
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model._summon_full_params():
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            all_shards = next(model.parameters())
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0:my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))
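
# A minimal sketch (an assumption, not taken from the tests above) of what the
# `get_expected_sharded_size` helper is expected to compute: FSDP pads the
# flattened parameter so it divides evenly across ranks, so each rank holds
# ceil(total_numel / world_size) elements.
def _expected_sharded_size_sketch(total_numel: int, world_size: int) -> int:
    return (total_numel + world_size - 1) // world_size
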
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank),
                     mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParamHandle.flatten_params(parameters,
                                                        requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0:my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            # The sleep below would surface failures if summon_full_params did
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP params has issues, so compare clones instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)
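
# A minimal sketch (assumed, not shown in the snippets) of the `DeterministicModel`
# used by the equivalence tests: a tiny two-layer model seeded so every rank builds
# identical weights, with the inner layer optionally wrapped in FSDP so that the
# outer FSDP instance contains a nested FSDP unit.
class _DeterministicModelSketch(nn.Module):
    def __init__(self, wrap_fsdp, cpu_offload=None):
        super().__init__()
        torch.manual_seed(0)  # identical initialization on every rank
        self.inner = nn.Linear(2, 2).cuda()
        if wrap_fsdp:
            self.inner = FSDP(self.inner, cpu_offload=cpu_offload)
        self.outer = nn.Linear(2, 2).cuda()

    def forward(self, x):
        return self.outer(self.inner(x))
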
 def test_ignored_modules_nested(self):
     """Tests that passing a module with nested FSDP modules does not
     error and still ignores non-FSDP modules' parameters."""
     # Initialize an FSDP-wrapped nested model that first wraps the nested
     # sequential's middle linear layer (`layer1[1]`) and then wraps the
     # overall model while ignoring the nested sequential (`layer1`)
     model = Model().cuda()
     model.layer1[1] = FSDP(model.layer1[1])
     wrapped_model = FSDP(model, ignored_modules=[model.layer1])
     # Check that the wrapped model's flattened parameter does not include
     # the ignored nested sequential's parameters
     nonwrapped_model = Model()
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.layer1.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     device = torch.device("cuda")
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     for _ in range(3):
         inp = wrapped_model.get_input(device)
         output = wrapped_model(*inp)
         loss = wrapped_model.get_loss(inp, output).to(device)
         wrapped_model.run_backward(loss)
         optim.step()
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """Tests that passing a CPU module to FSDP preserves that the wrapped
     module is on CPU after FSDP initialization, albeit after loging a
     warning, and that FSDP moves CPU input to GPU before the forward."""
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(
         expected_warning=UserWarning, expected_regex=regex
     )
     with context:
         nested_wrapped_module = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_NEVER,
         )
         fsdp_model = FSDP(nested_wrapped_module, self.process_group)
     devices = {p.device for p in fsdp_model.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp_model = fsdp_model.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = fsdp_model.module.get_input(device=torch.device("cpu"))
     fsdp_model(*inp).sum().backward()
    def test_input_type(self, input_cls):
        """Test FSDP with input being a list or a dict, only single GPU."""
        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()
    def test_diff_ignored_modules_across_ranks(
            self, pass_ignored_modules_to_root: bool):
        """
        Tests ignoring different modules across ranks.

        Args:
            pass_ignored_modules_to_root (bool): If ``False``, does not pass
                any ignored modules (including those already ignored in child
                FSDP instances) to the root FSDP instance; if ``True``, passes
                all ignored modules (representing a superset of the children's
                ignored modules) to the root FSDP instance.
        """
        # To exercise different `FlatParameter` enumerations across ranks,
        # we wrap `layer3` with FSDP, where `layer3` is registered as a module
        # after `layer1`, which has the variable number of ignored modules
        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
        layer1_ignored_modules = [
            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
        ]
        model.layer1 = FSDP(model.layer1,
                            ignored_modules=layer1_ignored_modules)
        model.layer3 = FSDP(model.layer3)
        model_ignored_modules = [
            m for m in model.modules() if isinstance(m, IgnoredModule)
        ] if pass_ignored_modules_to_root else []
        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._train_model(wrapped_model, optim, 3)
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """
     Ensure that a module initialized on CPU stays on CPU after FSDP
     initialization and that a warning is logged.
     """
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(expected_warning=UserWarning,
                                     expected_regex=regex)
     with context:
         mod = NestedWrappedModule(
             group=self.process_group,
             wrap_fsdp=True,
             wrap_everything=True,
             fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
         )
         fsdp = FSDP(mod)
     devices = {p.device for p in fsdp.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp = fsdp.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = mod.get_input(device=torch.device("cpu"))
     fsdp(inp[0]).sum().backward()
 def _test_mixed_precision_embedding_table(self, mp_config):
     # Basic test to ensure int inputs are not cast, which would break
     # modules such as embedding tables.
     param_dtype = mp_config.param_dtype or torch.float32
     orig_reduce_scatter = dist._reduce_scatter_base
     test_reduce_scatter = partial(
         self._reduce_scatter_base_validate_mp,
         orig_reduce_scatter,
         mp_config,
     )
     with patch_reduce_scatter(test_reduce_scatter, param_dtype):
         # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
         # entire `TransformerWithSharedParams` with a single top-level FSDP
         model = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             {"mixed_precision": mp_config},
         )
         fsdp_model = FSDP(model, mixed_precision=mp_config)
         optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
         for _ in range(6):
             inp = fsdp_model.module.get_input(torch.device("cuda"))
             # This would fail if we cast integer module inputs, such as those
             # for embedding tables.
             output = fsdp_model(*inp)
             loss = fsdp_model.module.get_loss(inp, output).cuda()
             self.assertEqual(loss.dtype, param_dtype)
             fsdp_model.module.run_backward(loss)
             optim.step()
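
# A hypothetical sketch of the `_reduce_scatter_base_validate_mp` hook used above
# (the name and exact signature are assumptions): wrap the original reduce-scatter,
# check that every communicated tensor uses the configured low-precision dtype,
# then delegate to the real collective.
def _reduce_scatter_base_validate_mp_sketch(orig_reduce_scatter, mp_config,
                                            *args, **kwargs):
    expected_dtype = (mp_config.reduce_dtype or mp_config.param_dtype
                      or torch.float32)
    for arg in args:
        if isinstance(arg, torch.Tensor):
            assert arg.dtype == expected_dtype, (arg.dtype, expected_dtype)
    return orig_reduce_scatter(*args, **kwargs)
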
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))

        # Check that the ignored parameters are not cloned
        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
 def test_ignored_modules_transformer(self):
     """Tests that ignored modules' parameters are not flattened for a
     transformer model with shared parameters."""
     # Initialize an FSDP-wrapped transformer model that has FSDP ignore
     # the `nn.Transformer` module's parameters
     model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     wrapped_model = FSDP(
         model,
         self.process_group,
         ignored_modules=[model.transformer],
     )
     # Check that the wrapped model's flattened parameter does not include
     # the ignored transformer module's parameters
     nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.transformer.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
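
# A minimal sketch (assumed) of the `_train_model` helper called above, mirroring
# the manual loop in `test_ignored_modules_nested`: run a few iterations using the
# test model's own input/loss helpers.
def _train_model_sketch(model, optim, num_iters, device=torch.device("cuda")):
    for _ in range(num_iters):
        optim.zero_grad()
        inp = model.get_input(device)
        output = model(*inp)
        loss = model.get_loss(inp, output).to(device)
        model.run_backward(loss)
        optim.step()
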
    def test_multiple_wrapping(self):
        """
        This test simulates wrapping the module after training to run inference.
        It covers the case where, later in a session, the model is wrapped again
        in FSDP while the module already contains nested FSDP wrappers.
        """
        inner_model = InnerModel()
        model = FSDP(inner_model).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for i in range(3):
            input = torch.rand((1, 5), dtype=torch.float).cuda()
            input.requires_grad = True
            output = model(input)
            output.sum().backward()
            optim.step()
            optim.zero_grad()
        input = torch.rand((1, 5), dtype=torch.float).cuda()
        output = model(input)

        # second time to rewrap the inner model
        rewrapped_model = FSDP(inner_model).cuda()
        rewrapped_output = rewrapped_model(input)

        self.assertEqual(output, rewrapped_output)
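
# A minimal sketch (assumed) of the `InnerModel` used above: a module whose
# submodule is already wrapped in FSDP, so wrapping the whole model again
# produces nested FSDP instances, which is what the test exercises.
class _InnerModelSketch(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = nn.Sequential(FSDP(nn.Linear(5, 5)))

    def forward(self, x):
        return self.layers(x)
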
    def test_params_are_unflattenned(self, rank0_only, offload_to_cpu):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

        def _get_flat_param():
            return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

        flattened_param = _get_flat_param()
        self.assertEqual(layer_shape[0] * layer_shape[1] / 2,
                         flattened_param.numel())

        with fsdp_model.summon_full_params(rank0_only=rank0_only,
                                           writeback=not rank0_only,
                                           offload_to_cpu=offload_to_cpu):
            if self.rank == 0 or not rank0_only:
                self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
                expected_device = (torch.device("cpu")
                                   if offload_to_cpu else torch.device(
                                       "cuda", torch.cuda.current_device()))
                self.assertTrue(expected_device == fsdp_model.weight.device)
            else:
                # Nonzero rank with rank0_only maintains original params.
                flat_within_ctx = _get_flat_param()
                self.assertEqual(flat_within_ctx, flattened_param)
                self.assertEqual(flat_within_ctx.device,
                                 torch.device(torch.cuda.current_device()))

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
    def _test_nested_model_with_meta_device(self, auto_wrap, meta_module_fn, init_fn=None):
        if auto_wrap:
            module = meta_module_fn()
            is_meta = next(module.parameters()).is_meta
            fsdp_meta = FSDP(
                module,
                auto_wrap_policy=always_wrap,
                param_init_fn=init_fn,
            )
            meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            fsdp_regular = FSDP(
                module_regular,
                auto_wrap_policy=always_wrap,
            )
            regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)
        else:
            with enable_wrap(
                wrapper_cls=FSDP, param_init_fn=init_fn,
            ):
                module = meta_module_fn()
                is_meta = next(module.parameters()).is_meta
                # Non-FSDP modules will still be initialized because they bubble
                # up to be part of a larger FSDP unit.
                fsdp_meta = wrap(module)
                meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

            # Init and reset parameters before wrapping so that reset_params
            # matches up with meta device's initialization.
            module_regular = NestedModel(device="cuda")
            _reset_params_if_meta(is_meta, module_regular)
            with enable_wrap(wrapper_cls=FSDP):
                module_regular.lin1 = wrap(module_regular.lin1)
                module_regular.l3 = wrap(module_regular.l3)
                fsdp_regular = wrap(module_regular)
                regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Compare it before training
        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device='cuda')
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)
    def _test_simple_model_with_meta_device(self, meta_module_fn, init_fn=None):
        # Create model on meta device and wrap with FSDP.
        model = meta_module_fn()
        is_meta = next(model.parameters()).is_meta
        fsdp_meta = FSDP(
            model,
            auto_wrap_policy=always_wrap,
            param_init_fn=init_fn,
        )

        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)

        # Test to make sure it is the same model parameters as regular FSDP
        # approach.
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        self._compare_fsdp(fsdp_meta, fsdp_regular)
        inp = torch.randn(10, 2, device='cuda')
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)

        # Test that meta init works if all submodules are contained in only a
        # single FSDP unit.
        model = meta_module_fn()
        fsdp_meta = FSDP(model, param_init_fn=init_fn)
        meta_opt = torch.optim.SGD(fsdp_meta.parameters(), lr=1e-3)
        regular = MyModel(device="cuda")
        _reset_params_if_meta(is_meta, regular)
        fsdp_regular = FSDP(regular, auto_wrap_policy=always_wrap)
        regular_opt = torch.optim.SGD(fsdp_regular.parameters(), lr=1e-3)

        # Run a forward + backward pass + optimizer step
        fsdp_meta(inp).sum().backward()
        fsdp_regular(inp).sum().backward()
        meta_opt.step()
        regular_opt.step()
        self._compare_fsdp(fsdp_meta, fsdp_regular)
    def _dist_train(self, with_nested_trunk, freezing_method,
                    freeze_after_wrap_fsdp, with_fsdp):
        torch.manual_seed(0)
        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = self._create_model(with_fsdp, with_nested_trunk,
                                   freeze_after_wrap_fsdp)
        model = model.cuda()

        # freezing the trunk using requires_grad.
        if freezing_method == FreezingMethod.RequiresGrad:
            for param in model.trunk.parameters():
                param.requires_grad = False

        if with_fsdp:
            if not freeze_after_wrap_fsdp:
                model.fsdp_wrap()
            model = FSDP(model)
        else:
            model = DistributedDataParallel(model, device_ids=[self.rank])

        target = torch.tensor([0, 1], dtype=torch.long).cuda()
        criterion = nn.CrossEntropyLoss()
        optimizer = optim.SGD(model.parameters(), lr=0.1, momentum=0.9)

        for iteration in range(3):
            out = model(batch)
            fake_loss = criterion(out, target)
            optimizer.zero_grad()
            fake_loss.backward()
            if freezing_method == FreezingMethod.GradToNone:
                if with_fsdp:
                    for param in model.module.module.trunk.parameters():
                        param.grad = None
                else:
                    for param in model.module.trunk.parameters():
                        param.grad = None
            optimizer.step()

        if with_fsdp:
            get_full_params(model)

        return list(model.parameters())
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(
            DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
        )
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = (
            [p.clone() for p in model.parameters()]
            if rank0_only and self.rank != 0
            else list(local_model.parameters())
        )

        writeback = not rank0_only

        with model.summon_full_params(
            model,
            recurse=True,
            rank0_only=rank0_only,
            writeback=writeback,
            offload_to_cpu=offload_to_cpu,
        ):
            if writeback:
                with torch.no_grad():
                    for p in model.parameters():
                        p.add_(1)
                    for p in params_to_compare:
                        p.add_(1)
            # The sleep below would surface failures if summon_full_params did
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP params has issues, so compare clones instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)

        # CPU offload is enabled, so the params should be back on CPU after
        # exiting the context
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("cpu"))
    def test_summon_full_params_equivalence(self):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        with model.summon_full_params(recurse=True):
            # The sleep below would surface failures if summon_full_params did
            # not synchronize streams correctly.
            torch.cuda._sleep(1000000)
            fsdp_params = deepcopy(list(model.parameters()))

        self.assertEqual(fsdp_params, list(local_model.parameters()))
    def _dist_train(self, with_checkpoint, expected, model_hidden_dim, iterations):
        gpu_id = self.rank
        world_size = self.world_size

        batch = torch.randn(size=(2, 3, 224, 224)).cuda()

        model = create_model(
            with_fsdp=True,
            with_checkpoint=with_checkpoint,
            model_hidden_dim=model_hidden_dim,
        )
        model = model.cuda()
        model = FSDP(model)

        # We enable momentum so that after the first iteration, the optimizer state is added
        # to the total memory used.
        criterion = nn.MSELoss()
        optimizer = optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)

        results = {}  # results of memory stats
        for iteration in range(iterations):
            get_cur_mem(gpu_id, results, f"iter {iteration}: start")

            out = model(batch)
            get_cur_mem(gpu_id, results, f"iter {iteration}: after fwd")

            out = sum(o.sum() for o in out[0])
            fake_loss = criterion(out, torch.tensor(0.0).cuda())
            get_cur_mem(gpu_id, results, f"iter {iteration}: after loss")

            fake_loss.backward()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after bwd")

            optimizer.step()
            get_cur_mem(gpu_id, results, f"iter {iteration}: after step")

            # It is important to use `set_to_none=True` below, not
            # optimizer.zero_grad(), to reclaim memory.
            model.zero_grad(set_to_none=True)
            get_cur_mem(gpu_id, results, f"iter {iteration}: done")

        def cmp(results, expected):
            ret = ""
            self.assertEqual(results.keys(), expected.keys())
            for k, v in results.items():
                exp = expected[k]
                if abs(exp - v) > 1:  # allow 1MB rounding differences
                    ret += f"{k}: got {v}, expected {exp}\n"
            return ret

        output = cmp(results, expected)
        self.assertEqual(output, "")
 def test_params_count_and_value(self):
     fsdp_model = FSDP(
         NestedWrappedModule(
             group=dist.distributed_c10d._get_default_group(),
             wrap_fsdp=True,
             fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
         ))
     model = NestedWrappedModule(
         group=dist.distributed_c10d._get_default_group(),
         wrap_fsdp=False,
         fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
     )
     with fsdp_model.summon_full_params():
         for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                             model.module.parameters()):
             self.assertEqual(p1, p2)
 def test_state_dict_with_ignored_modules(self):
     # Initialize an FSDP-wrapped model with an ignored module that includes
     # both parameters and a buffer
     model = Model(wrap_fsdp=True, register_buffers=True).cuda()
     ignored_modules = [model.outer]
     ignored_tensor_to_tensor_name = {
         model.outer.bias: "outer.bias",
         model.outer.weight: "outer.weight",
         model.outer.buffer: "outer.buffer",
     }
     buffer_to_buffer_name = {
         model.inner.buffer: "inner.buffer", model.outer.buffer: "outer.buffer",
     }
     fsdp_model = FSDP(model, ignored_modules=ignored_modules)
     with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
         sd1 = fsdp_model.state_dict()
     with FSDP.summon_full_params(fsdp_model):
         fsdp_params = deepcopy(list(fsdp_model.parameters()))
     # Check that the ignored parameters and all buffers are not cloned
     for tensor, tensor_name in {
         **ignored_tensor_to_tensor_name,
         **buffer_to_buffer_name,
     }.items():
         self.assertTrue(tensor_name in sd1)
         self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
     # Check that the state dict can be loaded into a non-wrapped version of
     # the model
     nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
     for param in nonwrapped_model.parameters():
         with torch.no_grad():
             param.zero_()
     nonwrapped_model.load_state_dict(sd1)
     local_params = list(nonwrapped_model.parameters())
     for fsdp_param, local_param in zip(fsdp_params, local_params):
         self.assertEqual(fsdp_param, local_param)
     # Check that if we save a state dict again, the ignored parameters and
     # buffer still have the same data pointer
     with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
         sd2 = fsdp_model.state_dict()
     for tensor, tensor_name in {
         **ignored_tensor_to_tensor_name,
         **buffer_to_buffer_name,
     }.items():
         self.assertTrue(tensor_name in sd1)  # check again just in case
         self.assertTrue(tensor_name in sd2)
         self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
         self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())
 def test_ignored_modules_nested(self):
     """Tests that passing a module with nested FSDP modules does not
     error and still ignores non-FSDP modules' parameters."""
     # Initialize an FSDP-wrapped nested model that first wraps the nested
     # sequential's second linear layer (`layer1[1]`) and then wraps the
     # overall model while ignoring the nested sequential (`layer1`)
     model = Model().cuda()
     model.layer1[1] = FSDP(model.layer1[1])
     wrapped_model = FSDP(model, ignored_modules=[model.layer1])
     # Check that the wrapped model's flattened parameter does not include
     # the ignored nested sequential's parameters
     nonwrapped_model = Model()
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.layer1.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
    def test_state_dict_with_ignored_modules(self, prefix, ignore_inner):
        # Initialize an FSDP-wrapped model with an ignored module that includes
        # both parameters and a buffer
        model = Model(wrap_fsdp=True,
                      register_buffers=True,
                      ignore_inner=ignore_inner).cuda()
        ignored_modules = [model.outer]
        ignored_tensor_to_tensor_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        if ignore_inner:
            ignored_tensor_to_tensor_name = {
                **ignored_tensor_to_tensor_name,
                model.inner.bias: "inner.bias",
                model.inner.weight: "inner.weight",
            }
        # Note that when model.inner is not ignored, this test also ensures that
        # non-ignored buffers are not cloned.
        buffer_to_buffer_name = {
            model.inner.buffer: "inner.buffer",
            model.outer.buffer: "outer.buffer",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        prefix_str = "foo." if prefix else ""
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd1 = fsdp_model.state_dict(prefix=prefix_str)
        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters and all buffers are not cloned
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd1)
            self.assertEqual(tensor.data_ptr(),
                             sd1[prefixed_tensor_name].data_ptr(),
                             f"{prefixed_tensor_name}")
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        to_load = {k[len(prefix_str):]: v for k, v in sd1.items()}
        nonwrapped_model.load_state_dict(to_load, strict=True)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
        # Check that if we save a state dict again, the ignored parameters and
        # buffer still have the same data pointer
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd2 = fsdp_model.state_dict(prefix=prefix_str)
        for tensor, tensor_name in {
                **ignored_tensor_to_tensor_name,
                **buffer_to_buffer_name,
        }.items():
            prefixed_tensor_name = f"{prefix_str}{tensor_name}"
            self.assertTrue(prefixed_tensor_name in sd2)
            self.assertEqual(tensor.data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
            self.assertEqual(sd1[prefixed_tensor_name].data_ptr(),
                             sd2[prefixed_tensor_name].data_ptr())
 def test_state_dict_rank0_offload_save_load_flow(self):
     """Tests saving a model checkpoint only on rank 0 and loading it only
     on rank 0 with ``sync_module_states=True`` to emulate the workflow to
     avoid redundant CPU memory usage."""
     auto_wrap_policy = partial(
         transformer_auto_wrap_policy,
         transformer_layer_cls={
             TransformerEncoderLayer, TransformerDecoderLayer
         },
     )
     fsdp_kwargs = {"auto_wrap_policy": auto_wrap_policy}
     fsdp_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
         fsdp_kwargs,
     )
     # Force model parameters and buffers to be nonzero
     with FSDP.summon_full_params(fsdp_model):
         for tensor in itertools.chain(fsdp_model.parameters(),
                                       fsdp_model.buffers()):
             if torch.count_nonzero(tensor) == 0:
                 with torch.no_grad():
                     tensor.add_(
                         torch.tensor(1,
                                      dtype=tensor.dtype,
                                      device=tensor.device))
     with self._get_state_dict_mgr(fsdp_model, "state_dict", True):
         state_dict = deepcopy(_get_state_dict(fsdp_model))
     # Initialize a non-wrapped model on all ranks
     new_model = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
     )
     _zero_model(new_model, zero_buffers=True)
     # Only load the checkpoint on rank 0
     if self.rank == 0:
         new_model.load_state_dict(state_dict, strict=True)
     _assert_module_states(
         new_model,
         process_group=self.process_group,
         assert_fn=self.assertNotEqual,
     )
     # Broadcast the module states from rank 0 with `sync_module_states=True`
     new_fsdp_model = FSDP(
         new_model,
         device_id=torch.cuda.current_device(),
         auto_wrap_policy=auto_wrap_policy,
         sync_module_states=True,
     )
     # Check FSDP models are equal across ranks
     with FSDP.summon_full_params(new_fsdp_model):
         _assert_module_states(
             new_fsdp_model,
             process_group=self.process_group,
             assert_fn=self.assertEqual,
         )
     # Check FSDP models correctly loaded the checkpoint
     with FullyShardedDataParallel.summon_full_params(fsdp_model):
         with FullyShardedDataParallel.summon_full_params(new_fsdp_model):
             params = list(fsdp_model.parameters())
             params_new = list(new_fsdp_model.parameters())
             self.assertEqual(params, params_new)
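
# A hypothetical sketch of the `_assert_module_states` helper used above (the real
# implementation may differ): all-gather each rank's parameters and buffers as CPU
# tensors and apply `assert_fn` pairwise against rank 0's copy, so the caller can
# assert either equality or inequality across ranks.
def _assert_module_states_sketch(model, process_group, assert_fn):
    named_states = [(n, p.detach().cpu()) for n, p in model.named_parameters()]
    named_states += [(n, b.detach().cpu()) for n, b in model.named_buffers()]
    world_size = dist.get_world_size(process_group)
    gathered = [None for _ in range(world_size)]
    dist.all_gather_object(gathered, named_states, group=process_group)
    rank0_states = gathered[0]
    for other_states in gathered[1:]:
        for (_, t0), (_, t1) in zip(rank0_states, other_states):
            assert_fn(t0, t1)
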
    def _test_fsdp_parity(
        self,
        model_class: Type[FSDPTestModel],
        fsdp_init_mode: FSDPInitMode,
        cuda_init_mode: CUDAInitMode,
        ref_init_fn: Optional[Callable] = None,
        num_iters: int = 2,
        save_model: bool = True,
        cpu_offload: CPUOffload = CPUOffload(),
        backward_prefetch: Optional[BackwardPrefetch] = None,
        forward_prefetch: bool = False,
        sharding_strategy: Optional[ShardingStrategy] = None,
        mixed_precision: Optional[MixedPrecision] = None,
        enable_sharded_grad_scaler: bool = False,
        use_pure_fp16: bool = False,
        norm_type: Optional[Union[float, int]] = None,
        init_kwargs: Optional[Dict[str, Any]] = None,
        **fsdp_kwargs,
    ):
        """
        Tests FSDP training against a reference, which defaults to DDP but
        may be customized with ``ref_init_fn``.

        Args:
            model_class (Type[FSDPTestModel]): A model class that inherits from
                ``FSDPTestModel``, which defines the expected interface.
            fsdp_init_mode (FSDPInitMode): The mode to initialize the
                FSDP-wrapped model. This should not be ``NO_FSDP``.
            ref_init_fn (Optional[Callable]): A callable to invoke that wraps a
                non-wrapped model to construct the reference model, where this
                wrapper should provide data parallel semantics. If ``None``,
                then the callable defaults to the DDP constructor.
        """
        assert fsdp_init_mode != FSDPInitMode.NO_FSDP, "Expects an FSDP init mode that wraps with FSDP"
        if init_kwargs is None:
            init_kwargs = {}
        lr = 1e-2
        rank = self.process_group.rank()
        # Establish reference behavior with DDP
        model = model_class.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
            **init_kwargs,
        )
        if ref_init_fn is None:
            ref_model = DDP(model, device_ids=[rank], output_device=rank)
        else:
            ref_model = ref_init_fn(model)
        if use_pure_fp16:
            ref_model = ref_model.half()
        ref_loss = self._train_for_several_steps(
            ref_model,
            num_iters,
            autocast=mixed_precision is not None,
            lr=lr,
            fsdp_cpu_offload=cpu_offload,
            mixed_precision=mixed_precision,
            norm_type=norm_type,
            enable_sharded_grad_scaler=enable_sharded_grad_scaler,
            use_pure_fp16=use_pure_fp16,
        )
        ddp_params = list(ref_model.parameters())
        # Check against FSDP behavior
        fsdp_kwargs.update({
            "cpu_offload": cpu_offload,
            "backward_prefetch": backward_prefetch,
            "forward_prefetch": forward_prefetch,
            "sharding_strategy": sharding_strategy,
            "mixed_precision": mixed_precision,
        })
        try:
            fsdp_model = model_class.init(
                self.process_group,
                fsdp_init_mode,
                cuda_init_mode,
                fsdp_kwargs,
                deterministic=True,
                **init_kwargs,
            )
        except Exception as e:
            raise ValueError(
                f"Initializing {model_class} raised error {str(e)}") from e
        if not isinstance(fsdp_model, FSDP):
            # Enforce that we wrap with top-level FSDP since we are comparing
            # assuming a data parallel reference and some test models may not
            # do so in their `init()` method
            fsdp_model = FSDP(fsdp_model, self.process_group, **fsdp_kwargs)
        if use_pure_fp16:
            # Change the model parameter dtype after FSDP initialization
            fsdp_model = fsdp_model.half()
        if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
            fsdp_model = fsdp_model.cuda()
        offload_params = cpu_offload is not None and cpu_offload.offload_params
        # Offloading parameters with `CUDA_AFTER` should raise an error during
        # lazy initialization due to the parameter devices not being CPU;
        # otherwise, all parameter devices should be CPU
        expects_device_error = offload_params and cuda_init_mode == CUDAInitMode.CUDA_AFTER
        expects_cpu_device = offload_params and cuda_init_mode != CUDAInitMode.CUDA_AFTER
        if expects_cpu_device:
            cpu_device = torch.device("cpu")
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
        context = (self.assertRaisesRegex(AssertionError,
                                          "Expected param to be on CPU")
                   if expects_device_error else suppress())
        with context:
            fsdp_loss = self._train_for_several_steps(
                fsdp_model,
                num_iters,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
                mixed_precision=mixed_precision,
                norm_type=norm_type,
                enable_sharded_grad_scaler=enable_sharded_grad_scaler,
                use_pure_fp16=use_pure_fp16,
            )
        # No need to check for parameter and loss parity if expecting an error
        if expects_device_error:
            return
        # Check parameter devices are CPU if offloading to CPU before calling
        # `get_full_params()`, which will cast the parameters to FP32
        if offload_params:
            for param in fsdp_model.parameters():
                self.assertEqual(param.device, cpu_device)
            fsdp_loss = fsdp_loss.cuda()
        fsdp_unsharded_params = get_full_params(fsdp_model)
        torch.testing.assert_allclose(ref_loss, fsdp_loss)
        # Do not check for parameter parity if using mixed precision since (1)
        # the DDP parameters are in FP16 (from `half()`) while the FSDP
        # parameters are in FP32 (from `summon_full_params()`) and (2) DDP runs
        # the optimizer in FP16 while FSDP runs it in FP32
        if mixed_precision is None:
            self.assertEqual(
                ddp_params,
                fsdp_unsharded_params,
                exact_device=True,
                msg="FSDP did not match DDP",
            )
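
# A minimal sketch (assumed) of the `get_full_params` helper used in the parity
# tests above: gather the full unsharded parameters via `summon_full_params` and
# return detached clones so they stay valid after the context exits.
def _get_full_params_sketch(fsdp_model):
    with FSDP.summon_full_params(fsdp_model):
        return [p.detach().clone() for p in fsdp_model.parameters()]
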
def _zero_model(fsdp_model: FullyShardedDataParallel):
    with fsdp_model.summon_full_params():
        for param in fsdp_model.parameters():
            with torch.no_grad():
                param.zero_()
    def _test_identical_outputs(self,
                                model_init_fn,
                                *args,
                                ref_ddp_fn=None,
                                num_steps=2,
                                fsdp_init_mode=FSDPInitMode.CUDA_AFTER,
                                lr=0.01,
                                cpu_offload=CPUOffload(),
                                backward_prefetch=None,
                                sharding_strategy=None,
                                save_model=True,
                                clip_norm=0.3,
                                norm_type=None,
                                **kwargs):
        group = dist.distributed_c10d._get_default_group()
        rank = group.rank()
        # Establish reference behavior with PyTorch DDP (+ optionally autocast).
        model = model_init_fn(group=group, wrap_fsdp=False).cuda()
        if ref_ddp_fn is None:
            model = nn.parallel.DistributedDataParallel(model,
                                                        device_ids=[rank],
                                                        output_device=rank)
        else:
            model = ref_ddp_fn(model)

        # DDP training
        ref_loss = self._train_for_several_steps(model,
                                                 num_steps,
                                                 autocast=False,
                                                 lr=lr,
                                                 fsdp_cpu_offload=cpu_offload)
        ref_full_params = list(model.parameters())

        # Confirm we get the same behavior using FullyShardedDataParallel.
        try:
            model = model_init_fn(
                group=group,
                wrap_fsdp=True,
                fsdp_init_mode=fsdp_init_mode,
                cpu_offload=cpu_offload,
                backward_prefetch=backward_prefetch,
                sharding_strategy=sharding_strategy,
            )
        except Exception as e:
            raise ValueError(
                f"model_init_fn {model_init_fn} got error {str(e)}") from e

        cpu_offload = cpu_offload or CPUOffload()  # disabled if not specified.
        model = FullyShardedDataParallel(
            model,
            cpu_offload=cpu_offload,
            backward_prefetch=backward_prefetch,
            sharding_strategy=sharding_strategy,
        )
        # Call model.cuda() after init FSDP if specified.
        if fsdp_init_mode == FSDPInitMode.CUDA_AFTER:
            model = model.cuda()

        # Note that we don't do this check for FSDPInitMode.CUDA_AFTER since we
        # expect FSDP to raise the error (checked below) in the case of
        # offloaded params.
        if fsdp_init_mode != FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params:
            for p in model.parameters():
                # Should be on CPU regardless of if param is sharded.
                self.assertEqual(p.device, torch.device("cpu"),
                                 f"Mismatch, cpu offload is {cpu_offload}")

        only_check_err = fsdp_init_mode == FSDPInitMode.CUDA_AFTER and cpu_offload.offload_params
        ctx = (self.assertRaisesRegex(AssertionError,
                                      "Expected param to be on CPU")
               if only_check_err else suppress())
        with ctx:
            # FSDP training
            shard_loss = self._train_for_several_steps(
                model,
                num_steps,
                autocast=False,
                lr=lr,
                fsdp_cpu_offload=cpu_offload,
                save_model=save_model,
            )
        # We only check for errors in the case we have the following setup:
        # model = FSDP(model, cpu_offload=True)
        # model = model.cuda()
        # so skip the rest of this logic.
        if only_check_err:
            return
        # If CPU offload, next call will change model params to GPU. Sanity
        # check that params are on CPU before.
        if cpu_offload.offload_params:
            device_set = {p.device for p in model.parameters()}
            self.assertEqual({torch.device("cpu")}, device_set,
                             f"Got device set {device_set}")
        shard_full_params = get_full_params(model)

        if cpu_offload.offload_params:
            shard_loss = shard_loss.cuda()
        torch.testing.assert_allclose(ref_loss, shard_loss)
        self.assertEqual(
            ref_full_params,
            shard_full_params,
            exact_device=True,
            msg="FullyShardedDataParallel didn't match PyTorch DDP",
        )
def _get_full_detached_param(fsdp_model: FullyShardedDataParallel):
    with fsdp_model.summon_full_params():
        params = list(p.clone().detach_() for p in fsdp_model.parameters())

    return params