Example #1
    def test_fsdp_same_model_across_ranks(self):
        """
        FSDP broadcasts model from rank 0 to ensure it starts off with the same
        values.
        """
        class MyModel(nn.Module):
            def __init__(self, rank):
                super().__init__()
                # Seed via rank to make model different across ranks
                torch.manual_seed(rank)
                torch.cuda.manual_seed(rank)
                self.lin = nn.Linear(10, 10, bias=False)
                self.register_buffer("buffer", torch.ones(1) * rank)

        m = MyModel(self.rank).cuda()
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states into FSDP makes the model the same
        # across ranks during init.
        fsdp = FSDP(m, sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)

        # sync_module_states also works with a CPU module when device_id is
        # passed in
        m = MyModel(self.rank)
        _assert_module_states(m, process_group=self.process_group, assert_fn=self.assertNotEqual)
        # Passing sync_module_states into FSDP makes the model the same
        # across ranks during init.
        fsdp = FSDP(m, device_id=torch.cuda.current_device(), sync_module_states=True)
        with fsdp.summon_full_params(fsdp):
            _assert_module_states(fsdp, process_group=self.process_group, assert_fn=self.assertEqual)
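
The test above relies on FSDP's sync_module_states flag, which broadcasts the module's parameters and buffers from rank 0 at wrap time. The following is a minimal standalone sketch of the same idea, assuming a launch via torchrun with one CUDA device per rank; the script layout and main() are illustrative and not part of the test suite.

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def main():
    dist.init_process_group("nccl")
    rank = dist.get_rank()
    torch.cuda.set_device(rank)
    # Deliberately seed differently per rank so the raw modules disagree.
    torch.manual_seed(rank)
    model = nn.Linear(10, 10, bias=False).cuda()
    # sync_module_states=True broadcasts rank 0's parameters and buffers.
    fsdp_model = FSDP(model, sync_module_states=True)
    with FSDP.summon_full_params(fsdp_model):
        flat = torch.cat([p.detach().flatten() for p in fsdp_model.parameters()])
    gathered = [torch.empty_like(flat) for _ in range(dist.get_world_size())]
    dist.all_gather(gathered, flat)
    # Every rank should now hold rank 0's initialization.
    assert all(torch.equal(gathered[0], g) for g in gathered)
    dist.destroy_process_group()


if __name__ == "__main__":
    main()
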
    def test_reshard_outside_forward_backward_iteration(
            self, rank0_only, offload_to_cpu, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 1, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # First, let's validate our assumption about resharding

        output = model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        # we reshard everything after backward() finishes
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # now let's repeat it with summon_full_params() done in between

        output = model(torch.zeros(5).cuda(self.rank))
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        output.backward()
        with model.summon_full_params(
                model,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            pass
        self.assertEqual(0, outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
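
As this test verifies, summon_full_params() used outside of forward and backward leaves the sharding state as it found it: after a forward pass the root FSDP unit still holds its full flat parameter while the inner unit is resharded, and after backward() everything is resharded. Below is a small hedged sketch of the common use case, reading full parameters between training steps; report_norms is an illustrative name, and a process group plus CUDA device are assumed to be set up already.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def report_norms(fsdp_model):
    """Collect full (unsharded) parameter norms outside forward/backward."""
    norms = {}
    # writeback=False: we only read, so skip the copy-back on exit.
    with FSDP.summon_full_params(fsdp_model, writeback=False):
        for name, param in fsdp_model.named_parameters():
            norms[name] = param.detach().norm().item()
    return norms
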
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     """Tests that ``named_parameters()`` and ``named_buffers()`` for a
     top-level FSDP-wrapped model match their behavior for the equivalent
     non-wrapped model."""
     model = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     model.register_buffer("buffer", torch.ones(1))
     # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
     # if called on a non-FSDP root module
     fsdp_model = FSDP(
         NestedWrappedModule.init(
             self.process_group,
             FSDPInitMode.NO_FSDP,
             CUDAInitMode.CUDA_BEFORE,
             deterministic=True,
         ),
         self.process_group,
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     with FSDP.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                     getattr(fsdp_model, call)(prefix=prefix,
                                               recurse=recurse),
                     getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
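
Outside of summon_full_params(), an FSDP-wrapped model reports FSDP-internal parameter names (e.g. the flattened parameter); inside the context the original names and shapes reappear, which is what the name comparison above relies on. A small hedged sketch follows; parameter_names is an illustrative helper and fsdp_model is assumed to be a top-level FSDP wrapper.

from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def parameter_names(fsdp_model):
    """Return (sharded_names, clean_names) for an FSDP-wrapped model."""
    # Outside the context: FSDP-internal names such as the flattened parameter.
    sharded_names = [name for name, _ in fsdp_model.named_parameters()]
    # Inside the context: the original module's parameter names and shapes.
    with FSDP.summon_full_params(fsdp_model):
        clean_names = [name for name, _ in fsdp_model.named_parameters()]
    return sharded_names, clean_names
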
Example #4
    def test_params_are_unflattenned(self, rank0_only, offload_to_cpu):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

        def _get_flat_param():
            return fsdp_model.get_parameter("_fsdp_wrapped_module.flat_param")

        flattened_param = _get_flat_param()
        self.assertEqual(layer_shape[0] * layer_shape[1] / 2,
                         flattened_param.numel())

        with fsdp_model.summon_full_params(rank0_only=rank0_only,
                                           writeback=not rank0_only,
                                           offload_to_cpu=offload_to_cpu):
            if self.rank == 0 or not rank0_only:
                self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
                expected_device = (torch.device("cpu")
                                   if offload_to_cpu else torch.device(
                                       "cuda", torch.cuda.current_device()))
                self.assertTrue(expected_device == fsdp_model.weight.device)
            else:
                # Nonzero rank with rank0_only maintains original params.
                flat_within_ctx = _get_flat_param()
                self.assertEqual(flat_within_ctx, flattened_param)
                self.assertEqual(flat_within_ctx.device,
                                 torch.device(torch.cuda.current_device()))

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
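
offload_to_cpu=True asks summon_full_params() to materialize the full parameters on CPU instead of on the GPU, trading speed for peak device memory, and rank0_only=True materializes them on rank 0 only while the other ranks keep their shards. Note that rank0_only=True combined with writeback=True is rejected, as a later example shows. Below is a hedged sketch combining the two flags for memory-friendly inspection; the helper name is illustrative and an initialized process group is assumed.

import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def full_params_on_rank0_cpu(fsdp_model):
    """Return CPU clones of the full parameters on rank 0, and [] elsewhere."""
    params = []
    with FSDP.summon_full_params(
        fsdp_model, rank0_only=True, writeback=False, offload_to_cpu=True
    ):
        if dist.get_rank() == 0:
            params = [p.detach().clone() for p in fsdp_model.parameters()]
    return params
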
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = ([
            p.clone() for p in model.parameters()
        ] if rank0_only and self.rank != 0 else list(local_model.parameters()))

        with model.summon_full_params(
                model,
                recurse=True,
                rank0_only=rank0_only,
                writeback=not rank0_only,
                offload_to_cpu=offload_to_cpu,
        ):
            # The sleep below would cause failures if summon_full_params did
            # not properly synchronize streams.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP-managed params has issues, so clone instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)
def _run_test_summon_full_param_writeback(cls, writeback, cpu_offload,
                                          modify_outer):
    model = FSDP(
        nn.Sequential(FSDP(nn.Linear(5, 5, bias=False)),
                      nn.Linear(5, 3, bias=False))).cuda(cls.rank)

    # set the value
    outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
    inner_param = model.get_parameter(
        "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param")
    p = outer_param if modify_outer else inner_param

    with torch.no_grad():
        # This sets the local shard value
        p[0] = cls.rank + 2

    with model.summon_full_params(writeback=writeback):
        with torch.no_grad():
            p.copy_(torch.zeros_like(p))

    if writeback or cls.world_size == 1:
        # When world_size == 1, FSDP does not shard, and the parameter is not
        # replaced by a local shard, so the write is always reflected.
        cls.assertEqual(p.cpu()[0], 0)
    else:
        cls.assertEqual(p.cpu()[0], cls.rank + 2)
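
The helper above exercises the writeback flag: edits made to the materialized full parameters inside summon_full_params() are copied back into the local shards only when writeback=True (with world_size == 1 nothing is sharded, so the write always sticks). A minimal hedged sketch of the intended usage; scale_weights is an illustrative name and an initialized process group is assumed.

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def scale_weights(fsdp_model, factor):
    """Scale every parameter in place and persist the change to the shards."""
    with FSDP.summon_full_params(fsdp_model, writeback=True):
        with torch.no_grad():
            for param in fsdp_model.parameters():
                param.mul_(factor)
    # With writeback=False, the scaling would be discarded on context exit.
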
    def test_summon_full_params_respects_reshard_after_forward(
            self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = FSDP(
            nn.Sequential(
                FSDP(nn.Linear(5, 5, bias=False),
                     mixed_precision=mixed_precision),
                nn.Linear(5, 3, bias=False),
            ),
            mixed_precision=mixed_precision,
        ).cuda(self.rank)

        outer_param = model.get_parameter("_fsdp_wrapped_module.flat_param")
        inner_param = model.get_parameter(
            "_fsdp_wrapped_module._fpw_module.0._fsdp_wrapped_module.flat_param"
        )
        outer_full_param_size = outer_param.numel() * self.world_size

        # trigger lazy init
        model(torch.zeros(5).cuda(self.rank))
        # the root FSDP module keeps all params around
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())

        # similarly summon_full_params should have the same behavior
        with model.summon_full_params(model):
            pass
        self.assertEqual(outer_full_param_size,
                         outer_param._full_param_padded.storage().size())
        self.assertEqual(0, inner_param._full_param_padded.storage().size())
 def test_ignored_modules_nested(self):
     """Tests that passing a module with nested FSDP modules does not
     error and still ignores non-FSDP modules' parameters."""
     # Initialize an FSDP-wrapped nested model that first wraps the nested
     # sequential's middle linear layer (`layer1[1]`) and then wraps the
     # overall model while ignoring the nested sequential (`layer1`)
     model = Model().cuda()
     model.layer1[1] = FSDP(model.layer1[1])
     wrapped_model = FSDP(model, ignored_modules=[model.layer1])
     # Check that the wrapped model's flattened parameter does not include
     # the ignored nested sequential's parameters
     nonwrapped_model = Model()
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.layer1.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     device = torch.device("cuda")
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     for _ in range(3):
         inp = wrapped_model.get_input(device)
         output = wrapped_model(*inp)
         loss = wrapped_model.get_loss(inp, output).to(device)
         wrapped_model.run_backward(loss)
         optim.step()
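
ignored_modules keeps the listed submodules' parameters out of FSDP's flat parameter: they remain ordinary, unsharded nn.Parameters that the user manages directly, which is exactly what the numel comparison above checks. A hedged sketch of just the wrapping step, using a throwaway nn.Sequential in place of the test's Model fixture; an initialized process group and a CUDA device are assumed.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

model = nn.Sequential(
    nn.Linear(8, 8),                                  # flattened/sharded by FSDP
    nn.Sequential(nn.Linear(8, 8), nn.Linear(8, 8)),  # ignored: left as-is
).cuda()
wrapped = FSDP(model, ignored_modules=[model[1]])
# The ignored submodule keeps its original parameters registered (2 weights
# and 2 biases), while the rest of the model is folded into the flat param.
assert len(list(model[1].parameters())) == 4

wrapped is then trained like any FSDP module; the ignored parameters are neither sharded nor synchronized by FSDP, so they behave like ordinary local parameters.
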
    def test_params_count_and_value(
        self,
        rank0_only: bool,
        offload_to_cpu: bool,
        mixed_precision: bool,
    ):
        mixed_precision = MixedPrecision() if mixed_precision else None
        model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        fsdp_model = NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.RECURSIVE,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        )
        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))
        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with FSDP.summon_full_params(fsdp_model,
                                     rank0_only=rank0_only,
                                     writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
 def test_ignored_modules_transformer(self):
     """Tests that ignored modules' parameters are not flattened for a
     transformer model with shared parameters."""
     # Initialize an FSDP-wrapped transformer model that has FSDP ignore
     # the `nn.Transformer` module's parameters
     group = dist.distributed_c10d._get_default_group()
     wrapped_model = self._get_wrapped_model(group, ignore_modules=True)
     # Check that the wrapped model's flattened parameter does not include
     # the ignored transformer module's parameters
     nonwrapped_model = self._get_nonwrapped_model(group)
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.transformer.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     device = torch.device("cuda")
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     for _ in range(3):
         inp = wrapped_model.module.get_input(device)
         output = wrapped_model(*inp)
         loss = wrapped_model.module.get_loss(inp, output).to(device)
         wrapped_model.module.run_backward(loss)
         optim.step()
Example #11
    def test_state_dict_with_ignored_modules(self):
        # Initialize an FSDP-wrapped model with an ignored module
        model = Model(wrap_fsdp=True).cuda()
        ignored_modules = [model.outer]
        ignored_param_to_param_name = {
            model.outer.bias: "outer.bias",
            model.outer.weight: "outer.weight",
        }
        fsdp_model = FSDP(model, ignored_modules=ignored_modules)
        with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
            sd = fsdp_model.state_dict()

        with FSDP.summon_full_params(fsdp_model):
            fsdp_params = deepcopy(list(fsdp_model.parameters()))
        # Check that the ignored parameters are not cloned

        for param, param_name in ignored_param_to_param_name.items():
            self.assertTrue(param_name in sd)
            self.assertEqual(param.data_ptr(), sd[param_name].data_ptr())
        # Check that the state dict can be loaded into a non-wrapped version of
        # the model
        nonwrapped_model = Model(wrap_fsdp=False).cuda()
        for param in nonwrapped_model.parameters():
            with torch.no_grad():
                param.zero_()

        nonwrapped_model.load_state_dict(sd)
        local_params = list(nonwrapped_model.parameters())
        for fsdp_param, local_param in zip(fsdp_params, local_params):
            self.assertEqual(fsdp_param, local_param)
    def test_summon_full_param_shard_value(self, mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        raw_model = nn.Linear(10, 11)
        raw_model_size = self.get_model_param_count(raw_model)
        expected_shard_size = self.get_expected_sharded_size(raw_model_size)

        model = FSDP(raw_model.cuda(self.rank),
                     mixed_precision=mixed_precision)
        self.assertEqual(expected_shard_size,
                         self.get_model_param_count(model))

        # we're assuming a single flattened param
        self.assertEqual(1, len(list(model.parameters())))

        my_shard = torch.clone(next(model.parameters()))

        with model.summon_full_params(model):
            self.assertEqual(raw_model_size, self.get_model_param_count(model))
            parameters = list(model.parameters())
            all_shards = FlatParamHandle.flatten_params(parameters,
                                                        requires_grad=False)
            my_slice = torch.chunk(all_shards, self.world_size)[self.rank]

            # shards are padded but the full_param tensor is not
            self.assertTrue(
                torch.equal(my_shard[0:my_slice.numel()].cpu(),
                            my_slice.cpu()))
    def test_params_count_and_value(self, rank0_only, offload_to_cpu,
                                    mixed_precision):
        mixed_precision = MixedPrecision() if mixed_precision else None
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
                mixed_precision=mixed_precision,
            ),
            mixed_precision=mixed_precision,
        )
        model = NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=False,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )

        dev = (torch.device("cpu") if offload_to_cpu else torch.device(
            "cuda", torch.cuda.current_device()))

        params_to_compare = ([p.to(dev) for p in model.module.parameters()]
                             if not rank0_only or self.rank == 0 else list(
                                 p.clone() for p in fsdp_model.parameters()))
        with fsdp_model.summon_full_params(fsdp_model,
                                           rank0_only=rank0_only,
                                           writeback=not rank0_only):
            for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                                params_to_compare):
                self.assertEqual(p1, p2)

        # CPU offload should restore the param device
        param = next(fsdp_model.parameters())
        self.assertTrue(
            param.device == torch.device("cuda", torch.cuda.current_device()))
Example #14
 def test_ignored_modules_transformer(self):
     """Tests that ignored modules' parameters are not flattened for a
     transformer model with shared parameters."""
     # Initialize an FSDP-wrapped transformer model that has FSDP ignore
     # the `nn.Transformer` module's parameters
     model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     wrapped_model = FSDP(
         model,
         self.process_group,
         ignored_modules=[model.transformer],
     )
     # Check that the wrapped model's flattened parameter does not include
     # the ignored transformer module's parameters
     nonwrapped_model: nn.Module = TransformerWithSharedParams.init(
         self.process_group,
         FSDPInitMode.NO_FSDP,
         CUDAInitMode.CUDA_BEFORE,
         deterministic=True,
     )
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.transformer.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
Example #15
 def check_weights(self, fsdp, expected_tensor_fn, check):
     with FSDP.summon_full_params(fsdp, recurse=True):
         linear_modules = [
             module for module in fsdp.modules() if type(module) == nn.Linear
         ]
         for module in linear_modules:
             for param in module.parameters():
                 expected = expected_tensor_fn(param)
                 check(param, expected, f"Got {param} but expected {expected}")
    def test_params_are_unflattenned(self):
        layer_shape = (10, 12)
        model = nn.Linear(*layer_shape, bias=False).cuda(self.rank)
        fsdp_model = FSDP(deepcopy(model)).cuda(self.rank)

        flattened_param = fsdp_model.get_parameter(
            "_fsdp_wrapped_module.flat_param")
        self.assertEqual(layer_shape[0] * layer_shape[1] / 2,
                         flattened_param.numel())

        with fsdp_model.summon_full_params():
            self.assertEqual(fsdp_model.weight.shape, model.weight.shape)
Example #17
def get_full_params(model: nn.Module, recurse: bool = True):
    """
    Returns the full unsharded parameters of ``model``. Any FSDP-managed
    parameters offloaded to CPU are moved to GPU in the returned list.

    Args:
        recurse (bool): If ``False``, only unshards the parameters immediate to
            ``model``; if ``True``, recurses through the module hierarchy
            rooted at ``model``.
    """
    with FSDP.summon_full_params(model, recurse=recurse):
        return deepcopy(list(model.parameters()))
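
A helper like this is handy for rank-local comparisons against a non-sharded reference model. A short hedged usage sketch; fsdp_model and reference are assumed to share the same architecture and deterministic initialization, and neither name comes from the snippet above.

import torch

for fsdp_p, ref_p in zip(get_full_params(fsdp_model), reference.parameters()):
    assert torch.equal(fsdp_p.cpu(), ref_p.detach().cpu())
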
Example #18
    def test_raises_rank0_with_writeback(self):
        fsdp_model = FSDP(
            NestedWrappedModule(
                group=dist.distributed_c10d._get_default_group(),
                wrap_fsdp=True,
                fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
            ))

        with self.assertRaisesRegex(ValueError, "is not supported"):
            with fsdp_model.summon_full_params(rank0_only=True,
                                               writeback=True):
                pass
Example #19
def _zero_model(
    model: nn.Module,
    zero_buffers: bool = False,
):
    """Zeros the parameters and optionally buffers of ``model`` in place."""
    with FSDP.summon_full_params(model):
        for param in model.parameters():
            with torch.no_grad():
                param.zero_()
        if zero_buffers:
            for buffer in model.buffers():
                with torch.no_grad():
                    buffer.zero_()
 def test_raises_rank0_with_writeback(self):
     """Tests that ``summon_full_params()`` with both ``rank0_only=True``
     and ``writeback=True`` raises an error."""
     nested_wrapped_module = NestedWrappedModule.init(
         self.process_group,
         FSDPInitMode.RECURSIVE,
         CUDAInitMode.CUDA_BEFORE,
     )
     with self.assertRaisesRegex(ValueError, "is not supported"):
         with FSDP.summon_full_params(nested_wrapped_module,
                                      rank0_only=True,
                                      writeback=True):
             pass
    def test_summon_full_params_equivalence(self):
        offload = CPUOffload(offload_params=True)
        model = FSDP(DeterministicModel(wrap_fsdp=True, cpu_offload=offload),
                     cpu_offload=offload)
        local_model = DeterministicModel(wrap_fsdp=False)

        with model.summon_full_params(recurse=True):
            # The sleep below would cause failures if summon_full_params did
            # not properly synchronize streams.
            torch.cuda._sleep(1000000)
            fsdp_params = deepcopy(list(model.parameters()))

        self.assertEqual(fsdp_params, list(local_model.parameters()))
 def test_params_count_and_value(self):
     fsdp_model = FSDP(
         NestedWrappedModule(
             group=dist.distributed_c10d._get_default_group(),
             wrap_fsdp=True,
             fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
         ))
     model = NestedWrappedModule(
         group=dist.distributed_c10d._get_default_group(),
         wrap_fsdp=False,
         fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
     )
     with fsdp_model.summon_full_params():
         for p1, p2 in itertools.zip_longest(fsdp_model.parameters(),
                                             model.module.parameters()):
             self.assertEqual(p1, p2)
 def test_state_dict_with_ignored_modules(self):
     # Initialize an FSDP-wrapped model with an ignored module that includes
     # both parameters and a buffer
     model = Model(wrap_fsdp=True, register_buffers=True).cuda()
     ignored_modules = [model.outer]
     ignored_tensor_to_tensor_name = {
         model.outer.bias: "outer.bias",
         model.outer.weight: "outer.weight",
         model.outer.buffer: "outer.buffer",
     }
     buffer_to_buffer_name = {
         model.inner.buffer: "inner.buffer", model.outer.buffer: "outer.buffer",
     }
     fsdp_model = FSDP(model, ignored_modules=ignored_modules)
     with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
         sd1 = fsdp_model.state_dict()
     with FSDP.summon_full_params(fsdp_model):
         fsdp_params = deepcopy(list(fsdp_model.parameters()))
     # Check that the ignored parameters and all buffers are not cloned
     for tensor, tensor_name in {
         **ignored_tensor_to_tensor_name,
         **buffer_to_buffer_name,
     }.items():
         self.assertTrue(tensor_name in sd1)
         self.assertEqual(tensor.data_ptr(), sd1[tensor_name].data_ptr())
     # Check that the state dict can be loaded into a non-wrapped version of
     # the model
     nonwrapped_model = Model(wrap_fsdp=False, register_buffers=True).cuda()
     for param in nonwrapped_model.parameters():
         with torch.no_grad():
             param.zero_()
     nonwrapped_model.load_state_dict(sd1)
     local_params = list(nonwrapped_model.parameters())
     for fsdp_param, local_param in zip(fsdp_params, local_params):
         self.assertEqual(fsdp_param, local_param)
     # Check that if we save a state dict again, the ignored parameters and
     # buffer still have the same data pointer
     with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT):
         sd2 = fsdp_model.state_dict()
     for tensor, tensor_name in {
         **ignored_tensor_to_tensor_name,
         **buffer_to_buffer_name,
     }.items():
         self.assertTrue(tensor_name in sd1)  # check again just in case
         self.assertTrue(tensor_name in sd2)
         self.assertEqual(tensor.data_ptr(), sd2[tensor_name].data_ptr())
         self.assertEqual(sd1[tensor_name].data_ptr(), sd2[tensor_name].data_ptr())
Example #24
    def test_state_dict_rank0_offload_save_load_flow(self):
        # Test taking a checkpoint on rank 0 only and reloading it without
        # redundant CPU memory.
        model = TransformerWithSharedParams(
            group=dist.distributed_c10d._get_default_group())
        my_auto_wrap_policy = partial(transformer_auto_wrap_policy,
                                      transformer_layer_cls={
                                          TransformerEncoderLayer,
                                          TransformerDecoderLayer
                                      })
        model = FSDP(model, auto_wrap_policy=my_auto_wrap_policy)
        ctx = self._get_state_dict_mgr(model, "state_dict", True)
        with ctx:
            state_dict = deepcopy(_get_state_dict(model))

        # All ranks initialize non-FSDP model
        grp = dist.distributed_c10d._get_default_group()
        model_new = TransformerWithSharedParams(group=grp)
        for p in model_new.parameters():
            with torch.no_grad():
                p.zero_()
        # Only rank 0 loads the checkpoint
        if self.rank == 0:
            model_new.load_state_dict(state_dict)

        # TransformerWithSharedParams has a buffer of zeros, so we cannot use
        # self.assertNotEqual (the buffers would be equal on all ranks).
        # Instead, just check that there is some difference in the model across
        # ranks before the state dict is broadcast.
        with self.assertRaisesRegex(AssertionError,
                                    "Tensor-likes are not close"):
            _validate(model_new, process_group=grp, assert_fn=self.assertEqual)
        # FSDP with sync_module_states=True broadcasts the checkpointed states.
        model_new = FSDP(model_new,
                         device_id=torch.cuda.current_device(),
                         auto_wrap_policy=my_auto_wrap_policy,
                         sync_module_states=True)
        # After wrapping with FSDP, the model is equal across ranks and has
        # loaded the checkpointed values.
        with FSDP.summon_full_params(model_new):
            _validate(model_new, process_group=grp, assert_fn=self.assertEqual)

        with FullyShardedDataParallel.summon_full_params(model):
            with FullyShardedDataParallel.summon_full_params(model_new):
                params = list(model.parameters())
                params_new = list(model_new.parameters())
                self.assertEqual(params, params_new)
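
The flow exercised above is the rank-0-only checkpointing recipe: gather a full state dict offloaded to CPU on rank 0, load it into a fresh unwrapped model on rank 0 only, and let FSDP(sync_module_states=True) broadcast the loaded weights when re-wrapping. A condensed, hedged sketch of that recipe; save_and_reload and build_model are illustrative names, an initialized NCCL group is assumed, and build_model() must construct the same architecture on every rank.

import torch
import torch.distributed as dist
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp import FullStateDictConfig, StateDictType


def save_and_reload(fsdp_model, build_model):
    # 1) Gather the full state dict, offloaded to CPU and only on rank 0.
    cfg = FullStateDictConfig(offload_to_cpu=True, rank0_only=True)
    with FSDP.state_dict_type(fsdp_model, StateDictType.FULL_STATE_DICT, cfg):
        full_sd = fsdp_model.state_dict()

    # 2) Only rank 0 loads the checkpoint into a fresh, unwrapped model.
    new_model = build_model()
    if dist.get_rank() == 0:
        new_model.load_state_dict(full_sd)

    # 3) Re-wrapping with sync_module_states=True broadcasts rank 0's weights.
    return FSDP(
        new_model,
        device_id=torch.cuda.current_device(),
        sync_module_states=True,
    )
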
    def test_summon_single_param(self):
        model = FSDP(nn.Linear(1, 1, bias=False)).cuda(self.rank)

        p = model.get_parameter("_fsdp_wrapped_module.flat_param")
        self.assertEqual(1, p.numel())

        with torch.no_grad():
            # This sets the local shard value
            p[0] = self.rank + 2

        with model.summon_full_params(model, writeback=True):
            self.assertEqual(1, p.numel())
            with torch.no_grad():
                p.copy_(torch.zeros_like(p))

        # Most ranks hold no data and wrote to padding, so only rank 0 will
        # observe the above write.
        if self.rank == 0:
            self.assertEqual(0, p[0])
        else:
            self.assertEqual(self.rank + 2, p[0])
Example #26
    def test_summon_full_params_equivalence(self, rank0_only, offload_to_cpu):
        offload = CPUOffload(offload_params=True)
        model = FSDP(
            DeterministicModel(wrap_fsdp=True, cpu_offload=offload), cpu_offload=offload
        )
        local_model = DeterministicModel(wrap_fsdp=False)

        params_to_compare = (
            [p.clone() for p in model.parameters()]
            if rank0_only and self.rank != 0
            else list(local_model.parameters())
        )

        writeback = not rank0_only

        with model.summon_full_params(
            model,
            recurse=True,
            rank0_only=rank0_only,
            writeback=writeback,
            offload_to_cpu=offload_to_cpu,
        ):
            if writeback:
                with torch.no_grad():
                    for p in model.parameters():
                        p.add_(1)
                    for p in params_to_compare:
                        p.add_(1)
            # The sleep below would cause failures if summon_full_params did
            # not properly synchronize streams.
            torch.cuda._sleep(1000000)
            # deepcopy() of FSDP-managed params has issues, so clone instead
            fsdp_params = [p.clone() for p in model.parameters()]

        self.assertEqual(fsdp_params, params_to_compare)

        # CPU offload is enabled, so params should be back on CPU after the
        # context exits
        for param in model.parameters():
            self.assertEqual(param.device, torch.device("cpu"))
Example #27
 def test_ignored_modules_nested(self):
     """Tests that passing a module with nested FSDP modules does not
     error and still ignores non-FSDP modules' parameters."""
     # Initialize an FSDP-wrapped nested model that first wraps the nested
     # sequential's second linear layer (`layer1[1]`) and then wraps the
     # overall model while ignoring the nested sequential (`layer1`)
     model = Model().cuda()
     model.layer1[1] = FSDP(model.layer1[1])
     wrapped_model = FSDP(model, ignored_modules=[model.layer1])
     # Check that the wrapped model's flattened parameter does not include
     # the ignored nested sequential's parameters
     nonwrapped_model = Model()
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.layer1.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
Example #28
 def test_named_parameters_buffers(self, prefix: str, recurse: bool):
     fsdp_model = FSDP(
         NestedWrappedModule(
             group=dist.distributed_c10d._get_default_group(),
             wrap_fsdp=True,
             fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
         )
     )
     fsdp_model.register_buffer("buffer", torch.ones(1))
     model = NestedWrappedModule(
         group=dist.distributed_c10d._get_default_group(),
         wrap_fsdp=False,
         fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
     )
     model.register_buffer("buffer", torch.ones(1))
     with fsdp_model.summon_full_params(fsdp_model):
         for call in ["named_parameters", "named_buffers"]:
             for (n1, p1), (n2, p2) in itertools.zip_longest(
                 getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                 getattr(model, call)(prefix=prefix, recurse=recurse),
             ):
                 self.assertEqual(n1, n2)
                 self.assertEqual(p1, p2)
 def test_ignored_modules_transformer(self):
     """Tests that ignored modules' parameters are not flattened for a
     transformer model with shared parameters."""
     # Initialize an FSDP-wrapped transformer model that has FSDP ignore
     # the `nn.Transformer` module's parameters
     group = dist.distributed_c10d._get_default_group()
     wrapped_model = self._get_wrapped_model(
         group,
         cuda_first=True,
         ignore_modules=True,
     )
     # Check that the wrapped model's flattened parameter does not include
     # the ignored transformer module's parameters
     nonwrapped_model = self._get_nonwrapped_model(group)
     total_numel = sum(p.numel() for p in nonwrapped_model.parameters())
     ignored_numel = sum(p.numel()
                         for p in nonwrapped_model.transformer.parameters())
     nonignored_numel = total_numel - ignored_numel
     with FSDP.summon_full_params(wrapped_model):
         flat_param_numel = wrapped_model.params[0].numel()
         self.assertEqual(flat_param_numel, nonignored_numel)
     # Check that we can run a few iterations
     optim = torch.optim.Adam(wrapped_model.parameters(), lr=1e-3)
     self._train_model(wrapped_model, optim, 3)
    def test_summon_from_non_fsdp(self):
        class FSDPContainer(nn.Module):
            def __init__(self, fsdp_1, fsdp_2, fsdp_3):
                super().__init__()
                self.fsdp_1 = fsdp_1
                self.fsdp_2 = fsdp_2
                self.fsdp_3 = fsdp_3

        model_fsdp = FSDPContainer(
            FSDP(DeterministicModel(wrap_fsdp=True)),
            FSDP(DeterministicModel(wrap_fsdp=True)),
            DeterministicModel(wrap_fsdp=False),
        )
        model_no_fsdp = FSDPContainer(
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
            DeterministicModel(wrap_fsdp=False),
        )

        params_to_compare = list(model_no_fsdp.parameters())
        with FSDP.summon_full_params(model_fsdp):
            fsdp_params = [p.clone() for p in model_fsdp.parameters()]

        self.assertEqual(params_to_compare, fsdp_params)
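
As this last example shows, FSDP.summon_full_params() can be called on a plain nn.Module whose children are FSDP instances; each nested FSDP unit is unsharded for the duration of the context, while non-FSDP children are untouched. A brief hedged sketch; Container is an illustrative class rather than one of the test fixtures, and an initialized process group plus a CUDA device are assumed.

import torch.nn as nn
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


class Container(nn.Module):
    def __init__(self):
        super().__init__()
        self.sharded = FSDP(nn.Linear(4, 4).cuda())  # FSDP child
        self.plain = nn.Linear(4, 4).cuda()          # ordinary child


container = Container()
with FSDP.summon_full_params(container):
    # The FSDP child exposes its full, unflattened parameters here;
    # the plain child is unaffected either way.
    full_shapes = [tuple(p.shape) for p in container.parameters()]
print(full_shapes)  # e.g. [(4, 4), (4,), (4, 4), (4,)] on every rank
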