Example #1
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     """Tests that ``named_parameters()`` and ``named_buffers()`` for a
     top-level FSDP-wrapped model matches their behavior for the equivalent
     non-wrapped model."""
     model = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     model.register_buffer("buffer", torch.ones(1))
     # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
     # if called on a non-FSDP root module
     fsdp_model = FSDP(
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             deterministic=True,
         ),
         self.process_group,
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     with FSDP.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                     getattr(fsdp_model, call)(prefix=prefix,
                                               recurse=recurse),
                     getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
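A minimal hedged sketch of the naming behavior this test relies on (assuming `fsdp_model` is the wrapped model above and the process group is already initialized): outside of summon_full_params(), FSDP exposes internal names such as "_fsdp_wrapped_module.flat_param", while inside the context the names are expected to match the non-wrapped model's.

# Hedged sketch, not part of the test above
outside_names = [n for n, _ in fsdp_model.named_parameters()]
# e.g. contains FSDP-internal prefixes such as "_fsdp_wrapped_module.flat_param"
with FSDP.summon_full_params(fsdp_model):
    inside_names = [n for n, _ in fsdp_model.named_parameters()]
# `inside_names` should line up with `model.named_parameters()` from above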
Example #2
    def test_diff_ignored_modules_across_ranks(
            self, pass_ignored_modules_to_root: bool):
        """
        Tests ignoring different modules across ranks.

        Args:
            pass_ignored_modules_to_root (bool): If ``False``, does not pass
                any ignored modules (including those already ignored in child
                FSDP instances) to the root FSDP instance; if ``True``, passes
                all ignored modules (representing a superset of the children's
                ignored modules) to the root FSDP instance.
        """
        # To exercise different `FlatParameter` enumerations across ranks,
        # we wrap `layer3` with FSDP, where `layer3` is registered as a module
        # after `layer1`, which has the variable number of ignored modules
        model = ModelWithIgnoredModules(num_ignored=self.rank + 1).cuda()
        layer1_ignored_modules = [
            m for m in model.layer1.modules() if isinstance(m, IgnoredModule)
        ]
        model.layer1 = FSDP(model.layer1,
                            ignored_modules=layer1_ignored_modules)
        model.layer3 = FSDP(model.layer3)
        model_ignored_modules = [
            m for m in model.modules() if isinstance(m, IgnoredModule)
        ] if pass_ignored_modules_to_root else []
        wrapped_model = FSDP(model, ignored_modules=model_ignored_modules)
        optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
        self._train_model(wrapped_model, optim, 3)
Example #3
 def _get_wrapped_model(
     self,
     group,
     cuda_first=False,
     config=None,
     **model_kwargs,
 ) -> FullyShardedDataParallel:
     if config is None:
         config = {}
     move_to_cuda = not ("cpu_offload" in config
                         and config["cpu_offload"].offload_params)
     if cuda_first:
         transformer = TransformerWithSharedParams(group, **model_kwargs)
         if move_to_cuda:
             transformer = transformer.cuda()
         model = FullyShardedDataParallel(transformer, group, **config)
     else:
         model = FullyShardedDataParallel(
             TransformerWithSharedParams(group, **model_kwargs),
             group,
             **config,
         )
         if move_to_cuda:
             model = model.cuda()
     return model
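A hedged usage sketch of the helper above (the test-class context and an initialized process group `group` are assumed): passing a CPUOffload config with offload_params=True makes the helper skip the .cuda() move and hand the CPU module to FullyShardedDataParallel.

# Hedged usage sketch, not taken from the original tests
from torch.distributed.fsdp import CPUOffload

wrapped = self._get_wrapped_model(
    group,
    cuda_first=False,
    config={"cpu_offload": CPUOffload(offload_params=True)},
)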
Example #4
    def test_fsdp_same_model_across_ranks(self):
        """
        FSDP broadcasts model from rank 0 to ensure it starts off with the same
        values.
        """
        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.register_buffer("buffer", torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states=True to FSDP makes the model the same
        # across ranks during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)

        # sync_module_states also works with a CPU module when device_id is
        # passed in
        m = MyModel(self.rank)
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states=True to FSDP makes the model the same
        # across ranks during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)
    def test_input_type(self, input_cls):
        """Test FSDP with input being a list or a dict, only single GPU."""
        class Model(Module):
            def __init__(self):
                super().__init__()
                self.layer = Linear(4, 4)

            def forward(self, input):
                if isinstance(input, list):
                    input = input[0]
                else:
                    assert isinstance(input, dict), input
                    input = input["in"]
                return self.layer(input)

        model = FSDP(Model()).cuda()
        optim = SGD(model.parameters(), lr=0.1)

        for _ in range(5):
            in_data = torch.rand(64, 4).cuda()
            in_data.requires_grad = True
            if input_cls is list:
                in_data = [in_data]
            else:
                self.assertTrue(input_cls is dict)
                in_data = {"in": in_data}

            out = model(in_data)
            out.sum().backward()
            optim.step()
            optim.zero_grad()
Example #6
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """Tests that passing a CPU module to FSDP preserves that the wrapped
     module is on CPU after FSDP initialization, albeit after loging a
     warning, and that FSDP moves CPU input to GPU before the forward."""
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(
         expected_warning=UserWarning, expected_regex=regex
     )
     with context:
         nested_wrapped_module = NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.RECURSIVE,
             CUDAInitMode.CUDA_NEVER,
         )
         fsdp_model = FSDP(nested_wrapped_module, self.process_group)
     devices = {p.device for p in fsdp_model.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp_model = fsdp_model.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = fsdp_model.module.get_input(device=torch.device("cpu"))
     fsdp_model(*inp).sum().backward()
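As a hedged companion sketch (assuming a GPU and an initialized process group): passing device_id asks FSDP to move a CPU-initialized module to that device during init, instead of calling .cuda() on the wrapped model afterwards as done above.

# Hedged sketch; `nn` and `self.process_group` are assumed from the test context
fsdp_model = FSDP(
    nn.Linear(5, 5),
    self.process_group,
    device_id=torch.cuda.current_device(),
)
# Parameters now live on the current CUDA device rather than on CPU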
Example #7
 def init(
     group: dist.ProcessGroup,
     fsdp_init_mode: FSDPInitMode,
     cuda_init_mode: CUDAInitMode,
     fsdp_kwargs: Optional[Dict[str, Any]] = None,
     deterministic: bool = False,
 ):
     """
     Initializes a :class:`NestedWrappedModule` instance, but unlike
     :meth:`NestedWrappedModule.init`, for the ``RECURSIVE`` init mode, this
     wraps with top-level FSDP and the ``always_wrap_policy()`` auto wrap
     policy.
     """
     super_ = super(AlwaysWrapNestedWrappedModule,
                    AlwaysWrapNestedWrappedModule)
     model = super_.init(
         group=group,
         fsdp_init_mode=FSDPInitMode.NO_FSDP,
         cuda_init_mode=cuda_init_mode,
         fsdp_kwargs=fsdp_kwargs,
         deterministic=deterministic,
     )
     if fsdp_init_mode == FSDPInitMode.NO_FSDP:
         return model
     elif fsdp_init_mode == FSDPInitMode.RECURSIVE:
         # Guard against ``fsdp_kwargs=None`` before unpacking it below
         fsdp_kwargs = fsdp_kwargs or {}
         fsdp_model = FSDP(model,
                           auto_wrap_policy=always_wrap_policy,
                           **fsdp_kwargs)
         if cuda_init_mode == CUDAInitMode.CUDA_AFTER:
             fsdp_model = fsdp_model.cuda()
         return fsdp_model
    def test_summon_full_param_shard_value(self):

        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank))
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model._summon_full_params():
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            all_shards = next(model.parameters())
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0:my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))
    def test_summon_full_param_writeback(
        self, writeback, cpu_offload, modify_outer
    ):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        # set the value
        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        p = outer_param if modify_outer else inner_param

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model._summon_full_params(writeback=writeback):
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        if writeback:
            self.assertEqual(p.cpu()[0], 0)
        else:
            self.assertEqual(p.cpu()[0], self.rank + 2)
Example #10
 def test_fsdp_cpu_init_stays_on_cpu(self):
     """
     Ensure that CPU model input stays on CPU
     after FSDP init and we log a warning.
     """
     torch.cuda.set_device(self.rank)
     regex = "Module is put on CPU"
     context = self.assertWarnsRegex(expected_warning=UserWarning,
                                     expected_regex=regex)
     with context:
         mod = NestedWrappedModule(
             group=self.process_group,
             wrap_fsdp=True,
             wrap_everything=True,
             fsdp_init_mode=FSDPInitMode.CUDA_NEVER,
         )
         fsdp = FSDP(mod)
     devices = {p.device for p in fsdp.parameters()}
     self.assertEqual(1, len(devices))
     self.assertEqual(torch.device("cpu"), devices.pop())
     fsdp = fsdp.cuda()
     # Ensure fwd + backward can be performed after moving to CUDA.
     # CPU input also tests that input is correctly moved to appropriate
     # CUDA device.
     inp = mod.get_input(device=torch.device("cpu"))
     fsdp(inp[0]).sum().backward()
    def test_summon_full_params_respects_reshard_after_forward(self):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))

        # the root FSDP module keeps all params around
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # similarly _summon_full_params should have the same behavior
        with model._summon_full_params():
            pass
        self.assertEqual(
            outer_full_param_size, outer_param._full_param_padded.storage().size()
        )
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
Example #12
 def test_wrong_state_dict_config(self):
     model = FSDP(Model(wrap_fsdp=True).cuda())
     with self.assertRaisesRegex(RuntimeError,
                                 "Expected state_dict_config of type"):
         with model.state_dict_type(model, StateDictType.FULL_STATE_DICT,
                                    LocalStateDictConfig()):
             pass
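For contrast, a hedged sketch of the matching pairings that are expected not to raise (assuming `fsdp_model` is an FSDP-wrapped module): each StateDictType is paired with its own config class, with FullStateDictConfig and LocalStateDictConfig both coming from torch.distributed.fsdp.

# Hedged sketch of the valid StateDictType/config pairings
with FSDP.state_dict_type(
    fsdp_model,
    StateDictType.FULL_STATE_DICT,
    FullStateDictConfig(offload_to_cpu=True, rank0_only=True),
):
    full_sd = fsdp_model.state_dict()
with FSDP.state_dict_type(
    fsdp_model, StateDictType.LOCAL_STATE_DICT, LocalStateDictConfig()
):
    local_sd = fsdp_model.state_dict()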
Example #13
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters are not cloned
        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
Example #14
 def test_ignored_modules_transformer(self):
     """Tests that ignored modules' parameters are not flattened for a
     transformer model with shared parameters."""
     # Initialize an FSDP-wrapped transformer model that has FSDP ignore
     # the `nn.Transformer` module's parameters
     model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     wrapped_model = FSDP(
         model,
         self.process_group,
         ignored_modules=[model.transformer],
     )
     # Check that the wrapped model's flattened parameter does not include
     # the ignored transformer module's parameters
     nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.transformer.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
    def test_params_count_and_value(self, rank0_only, offload_to_cpu,
                                    mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                mixed_precision=mixed_precision,
            ),
            mixed_precision=mixed_precision,
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with fsdp_model.summon_full_params(fsdp_model,
                                           rank0_only=rank0_only,
                                           writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
Example #16
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank), mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size, self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParameter(parameters, requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            a, b = my_shard[0 : my_slice.numel()], my_slice
            self.assertTrue(torch.equal(a.cpu(), b.cpu()))
 def _test_mixed_precision_embedding_table(self, mp_config):
     # Basic test to ensure int inputs are not cast, which would break
     # modules such as embedding tables.
     param_dtype = mp_config.param_dtype or torch.float32
     orig_reduce_scatter = dist._reduce_scatter_base
     test_reduce_scatter = partial(
         self._reduce_scatter_base_validate_mp,
         orig_reduce_scatter,
         mp_config,
     )
     with patch_reduce_scatter(test_reduce_scatter, param_dtype):
         # TODO: `test_mp_embedding_reduce()` fails if we do not wrap the
         # entire `TransformerWithSharedParams` with a single top-level FSDP
         model = TransformerWithSharedParams.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             {"mixed_precision": mp_config},
         )
         fsdp_model = FSDP(model, mixed_precision=mp_config)
         optim = torch.optim.SGD(fsdp_model.parameters(), lr=0.1)
         for _ in range(6):
             inp = fsdp_model.module.get_input(torch.device("cuda"))
             # This would fail if we cast integer module inputs such as for
             # embedding tables.
             output = fsdp_model(*inp)
             loss = fsdp_model.module.get_loss(inp, output).cuda()
             self.assertEqual(loss.dtype, param_dtype)
             fsdp_model.module.run_backward(loss)
             optim.step()
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            # This sleep would cause failures if summon_full_params did not
            # synchronize streams correctly.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP params has issues, so clone() them instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)
def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload,
                                          modify_outer):
    model = FSDP(
        nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)),
                      nn.Linear(5, 3, bias=False))).cuda(cls.rank)

    # set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param")
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard, so the parameter is never
        # replaced by a local shard and the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
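A hedged sketch of how writeback is typically used outside of tests (assuming `fsdp_model` is an FSDP-wrapped module): changes made to the full parameters inside the context persist to the local shards only when writeback=True.

# Hedged sketch: modify full parameters in place and write them back to shards
with FSDP.summon_full_params(fsdp_model, writeback=True):
    with torch.no_grad():
        for p in fsdp_model.parameters():
            p.zero_()
# With writeback=False the zeros above would be discarded when the context exits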
Example #20
    def test_params_are_unflattenned(self, rank0_only, offload_to_cpu):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

        def _get_flat_param():
            return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

        flattened_param = _get_flat_param()
        self.assertEqual(layer_shape[0] * layer_shape[1] / 2,
                         flattened_param.numel())

        with fsdp_model.summon_full_params(rank0_only=rank0_only,
                                           writeback=not rank0_only,
                                           offload_to_cpu=offload_to_cpu):
            if self.rank == 0 or not rank0_only:
                self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
                expected_device = (torch.device("cpu")
                                   if offload_to_cpu else torch.device(
                                       "cuda", torch.cuda.current_device()))
                self.assertTrue(expected_device == fsdp_model.weight.device)
            else:
                # Nonzero rank with rank0_only maintains original params.
                flat_within_ctx = _get_flat_param()
                self.assertEqual(flat_within_ctx, flattened_param)
                self.assertEqual(flat_within_ctx.device,
                                 torch.device(torch.cuda.current_device()))

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
    def test_summon_full_param_recursive(self, recurse, summon_outer):
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False)), nn.Linear(5, 3, bias=False)
            )
        ).cuda(self.rank)

        global_inner_numel = self.get_model_param_count(nn.Linear(5, 5, bias=False))
        global_outer_numel = self.get_model_param_count(nn.Linear(5, 3, bias=False))

        shard_inner_numel = int(math.ceil(global_inner_numel / self.world_size))
        shard_outer_numel = int(math.ceil(global_outer_numel / self.world_size))

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        self.assertEqual(shard_outer_numel, outer_param.numel())
        self.assertEqual(shard_inner_numel, inner_param.numel())

        model_to_summon = model if summon_outer else model[0]
        # outer is summoned if _summon_full_params is called on the outer FSDP module
        expected_outer_numel = global_outer_numel if summon_outer else shard_outer_numel

        # inner is summoned if _summon_full_params is called with recursion or
        # on the inner FSDP module
        expected_inner_numel = (
            global_inner_numel if recurse or not summon_outer else shard_inner_numel
        )

        with model_to_summon._summon_full_params(recurse=recurse):
            self.assertEqual(expected_outer_numel, outer_param.numel())
            self.assertEqual(expected_inner_numel, inner_param.numel())
    def test_mixed_precision_resnet(self):
        """
        End to end test to ensure mixed precision + auto_wrap works
        for ResNet model.
        """
        resnet_model = torchvision.models.resnet50().cuda()
        resnet_model = nn.SyncBatchNorm.convert_sync_batchnorm(
            resnet_model,
            process_group=dist.distributed_c10d._get_default_group())
        n_bn = sum(1 if isinstance(x, _BatchNorm) else 0
                   for x in resnet_model.modules())
        inp = torch.ones(1, 3, 1000, 1000, device='cuda')
        mp_config = MixedPrecision(
            param_dtype=torch.float16,
            reduce_dtype=torch.float16,
            buffer_dtype=torch.float16,
        )
        fsdp = FSDP(resnet_model,
                    auto_wrap_policy=size_based_auto_wrap_policy,
                    mixed_precision=mp_config)
        # BatchNorm units should be wrapped individually. Validate this by
        # ensuring that the number of FSDP units wrapping a BatchNorm equals
        # the number of BatchNorm units in the original ResNet model.
        fsdp_bn = 0
        for module in fsdp.fsdp_modules(fsdp):
            wrapped_module = module.module.module
            if isinstance(wrapped_module, _BatchNorm):
                fsdp_bn += 1

        self.assertEqual(fsdp_bn, n_bn)
        # Would throw type mismatch issue without mixed precision autowrapping.
        loss = fsdp(inp).sum()
        loss.backward()
 def test_state_dict_type(self):
     module = SkipModel(double_nest=True)
     with enable_wrap(wrapper_cls=FSDP):
         fsdp = wrap(module)
     with FSDP.state_dict_type(fsdp, StateDictType.LOCAL_STATE_DICT):
         pass
     for module in FSDP.fsdp_modules(fsdp):
         self.assertEqual(module._state_dict_type, StateDictType.FULL_STATE_DICT)
    def test_reshard_outside_forward_backward_iteration(
            self, rank0_only, offload_to_cpu, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # Now let's repeat it with a summon done in between

        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
    def test_optim_state_dict_nested(
        self,
        state_dict_type: StateDictType,
        use_multiple_param_groups: bool,
        rank0_only: bool,
        use_diff_optim_inputs: bool,
    ) -> None:
        """
        Tests :meth:`full_optim_state_dict` and :meth:`sharded_optim_state_dict`
        by comparing the returned dict for an FSDP-wrapped model with that of
        an equivalent non-wrapped model.

        The test checks the equivalence excluding the parameter keys since the
        FSDP and normal optimizer state dicts key by names and IDs,
        respectively. This means that the test can pass even if parameter keys
        are incorrectly mapped to values. Their correct mapping is tested in
        other tests that exercise the save/load workflow.
        """
        if rank0_only and state_dict_type == StateDictType.SHARDED_STATE_DICT:
            return  # not supported
        NUM_ITERS = 3
        model1, optim1, optim_input = self._init_nested_model(
            wrap=True, use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses1 = self._step_model(model1, optim1, num_iters=NUM_ITERS)
        if state_dict_type == StateDictType.FULL_STATE_DICT:
            fsdp_osd = FSDP.full_optim_state_dict(
                model1, optim1, optim_input, rank0_only=rank0_only,
            )
        else:
            fsdp_osd = FSDP.sharded_optim_state_dict(
                model1, optim1, optim_input
            )
        # Non-target ranks get an empty state dict
        if rank0_only and self.rank != 0:
            self.assertEqual(len(fsdp_osd), 0)
            return
        model2, optim2, _ = self._init_nested_model(
            wrap=False, use_multiple_param_groups=use_multiple_param_groups,
            use_diff_optim_inputs=use_diff_optim_inputs,
        )
        losses2 = self._step_model(model2, optim2, num_iters=NUM_ITERS)
        ref_osd = optim2.state_dict()
        # Check the losses to eliminate model drift as a source of error
        for i, (l1, l2) in enumerate(zip(losses1, losses2)):
            assert l1 == l2, f"Losses differ on iter {i}: {l1:.5f} {l2:.5f}"
        # Do not check the parameter keys since the full/sharded optimizer state
        # dict uses parameter names, while the non-wrapped equivalent uses
        # parameter IDs
        check_same_param_keys = False
        self._check_same_param_groups(
            fsdp_osd, ref_osd, check_same_param_keys=check_same_param_keys,
        )
        self._check_same_state(
            fsdp_osd, ref_osd, check_same_param_keys=check_same_param_keys,
        )
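The save/load workflow that the docstring above defers to other tests, as a hedged sketch (model1/optim1 are the FSDP-wrapped model and its optimizer from the test; rank0_only=False is used so every rank holds the full dict before re-sharding):

# Hedged sketch of the full optimizer state save/load round trip
full_osd = FSDP.full_optim_state_dict(model1, optim1, rank0_only=False)
sharded_osd = FSDP.shard_full_optim_state_dict(full_osd, model1)
optim1.load_state_dict(sharded_osd)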
 def wrap(sharding_strategy: ShardingStrategy,
          device: torch.device,
          init_policy=always_wrap_policy):
     model = Model()
     wrap_policy = ParamExecOrderWrapPolicy(init_policy=init_policy)
     fsdp_model = FSDP(model,
                       auto_wrap_policy=wrap_policy,
                       sharding_strategy=sharding_strategy)
     return fsdp_model.to(device)
Example #27
 def test_rekey_optim_state_dict_to_names(
     self,
     use_multiple_param_groups: bool,
 ):
     """Tests :meth:`rekey_optim_state_dict` with the new keys being
     parameter names by checking that a non-wrapped model (i.e. without FSDP
     modules) can rekey its optimizer state dict to match the expected
     output of :meth:`full_optim_state_dict`, hence be sharded using
     :meth:`shard_full_optim_state_dict`, and finally match the per-rank
     optimizer state dict of a wrapped model (i.e. with FSDP modules)."""
     NUM_ITERS = 3
     # Run a wrapped model for a few iterations
     model1, optim1, optim_input1 = self._init_nested_model(
         wrap=True,
         use_multiple_param_groups=use_multiple_param_groups,
     )
     self._step_model(model1, optim1, num_iters=NUM_ITERS)
     # Run a non-wrapped model for a few iterations
     model2, optim2, optim_input2 = self._init_nested_model(
         wrap=False,
         use_multiple_param_groups=use_multiple_param_groups,
     )
     self._step_model(model2, optim2, num_iters=NUM_ITERS)
     # Re-key the non-wrapped model's optimizer state dict using parameter
     # names (still according to itself)
     osd2 = optim2.state_dict()
     rekeyed_osd = FSDP.rekey_optim_state_dict(
         osd2,
         OptimStateKeyType.PARAM_NAME,
         model2,
         optim_input2,
     )
     # Shard the non-wrapped model's re-keyed optimizer state dict, which
     # maps back to (flattened) parameter IDs
     sharded_osd = FSDP.shard_full_optim_state_dict(
         rekeyed_osd,
         model1,
         optim_input1,
     )
     # Check that this sharded optimizer state dict matches the wrapped
     # model's per-rank optimizer state dict
     osd1 = optim1.state_dict()
     check_same_param_keys = True
     self._check_same_param_groups(
         sharded_osd,
         osd1,
         check_same_param_keys=check_same_param_keys,
     )
     self._check_same_state(
         sharded_osd,
         osd1,
         check_same_param_keys=check_same_param_keys,
     )
     # As a sanity check, check that we can load and run a few iterations
     optim1.load_state_dict(sharded_osd)
     self._step_model(model1, optim1, num_iters=NUM_ITERS)
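The reverse direction, as a hedged sketch reusing model1/optim1 (wrapped) and model2/optim2 (non-wrapped) from above: a consolidated FSDP optimizer state dict can be rekeyed to parameter IDs so that the non-wrapped model's optimizer can load it.

# Hedged sketch of rekeying in the opposite direction
full_osd = FSDP.full_optim_state_dict(model1, optim1, rank0_only=False)
osd_by_id = FSDP.rekey_optim_state_dict(
    full_osd, OptimStateKeyType.PARAM_ID, model2
)
optim2.load_state_dict(osd_by_id)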
 def wrap(
     sharding_strategy: ShardingStrategy,
     device: torch.device,
     wrap_policy: Callable,
 ) -> torch.nn.Module:
     model = Model()
     fsdp_model = FSDP(model,
                       auto_wrap_policy=wrap_policy,
                       sharding_strategy=sharding_strategy)
     return fsdp_model.to(device)
    def test_params_are_unflatenned(self):
        model = FSDP(nn.Linear(self.world_size, 1, bias=False)).cuda(self.rank)

        flattened_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, flattened_param.numel())

        with model._summon_full_params():
            a = model.weight.flatten().detach()
            b = flattened_param.detach()
            self.assertTrue(torch.equal(a, b))
 def __init__(self, wrap_fsdp, register_buffers=False):
     super().__init__()
     self.inner = Linear(*INNER_SHAPE)
     if register_buffers:
         self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
     if wrap_fsdp:
         self.inner = FSDP(self.inner)
     self.outer = Linear(*OUTER_SHAPE)
     if register_buffers:
         self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))