def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    """Tests that ``named_parameters()`` and ``named_buffers()`` for a
    top-level FSDP-wrapped model match their behavior for the equivalent
    non-wrapped model."""
    model = NestedWrappedModule.init(
        self.process_group,
        FSDPInitMode.NO_FSDP,
        CUDAInitMode.CUDA_BEFORE,
        deterministic=True,
    )
    model.register_buffer("buffer", torch.ones(1))
    # `named_parameters()` and `named_buffers()` will contain FSDP prefixes
    # if called on a non-FSDP root module
    fsdp_model = FSDP(
        NestedWrappedModule.init(
            self.process_group,
            FSDPInitMode.NO_FSDP,
            CUDAInitMode.CUDA_BEFORE,
            deterministic=True,
        ),
        self.process_group,
    )
    fsdp_model.register_buffer("buffer", torch.ones(1))
    with FSDP.summon_full_params(fsdp_model):
        for call in ["named_parameters", "named_buffers"]:
            for (n1, p1), (n2, p2) in itertools.zip_longest(
                getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                getattr(model, call)(prefix=prefix, recurse=recurse),
            ):
                self.assertEqual(n1, n2)
                self.assertEqual(p1, p2)
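
# Why ``itertools.zip_longest`` rather than plain ``zip`` in the loop above:
# if the wrapped and unwrapped models ever yielded a different number of
# entries, ``zip`` would silently truncate to the shorter side, while
# ``zip_longest`` pads the shorter side with ``None`` so the equality
# assertions fail loudly. A quick standalone illustration (plain Python,
# no FSDP involved):
import itertools

print(list(zip([1, 2], [1])))                    # [(1, 1)] -- extra entry silently dropped
print(list(itertools.zip_longest([1, 2], [1])))  # [(1, 1), (2, None)] -- mismatch surfaces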

import torch
from torch.nn import Linear, Module
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP

class Model(Module):
    def __init__(self, wrap_fsdp, register_buffers=False):
        super().__init__()
        self.inner = Linear(*INNER_SHAPE)
        if register_buffers:
            self.inner.register_buffer("buffer", torch.randn(BUFFER_SHAPE))
        if wrap_fsdp:
            self.inner = FSDP(self.inner)
        self.outer = Linear(*OUTER_SHAPE)
        if register_buffers:
            self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))

    def forward(self, x):
        # Forward twice.
        i = self.inner(x)
        j = self.inner(x)
        return self.outer(i + j)
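
# A minimal sketch of driving ``Model`` without FSDP wrapping. INNER_SHAPE,
# OUTER_SHAPE, and BUFFER_SHAPE are module-level constants in the original
# test file; the values below are assumed purely for illustration.
INNER_SHAPE = (4, 4)    # assumed example value
OUTER_SHAPE = (4, 4)    # assumed example value
BUFFER_SHAPE = (4,)     # assumed example value

m = Model(wrap_fsdp=False, register_buffers=True)
out = m(torch.randn(2, INNER_SHAPE[0]))  # inner runs twice, outputs summed, then outer
print(out.shape)  # torch.Size([2, 4]) with the assumed shapes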
Example #3
def test_named_parameters_buffers(self, prefix: str, recurse: bool):
    fsdp_model = FSDP(
        NestedWrappedModule(
            group=dist.distributed_c10d._get_default_group(),
            wrap_fsdp=True,
            fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
        )
    )
    fsdp_model.register_buffer("buffer", torch.ones(1))
    model = NestedWrappedModule(
        group=dist.distributed_c10d._get_default_group(),
        wrap_fsdp=False,
        fsdp_init_mode=FSDPInitMode.CUDA_BEFORE,
    )
    model.register_buffer("buffer", torch.ones(1))
    with fsdp_model.summon_full_params(fsdp_model):
        for call in ["named_parameters", "named_buffers"]:
            for (n1, p1), (n2, p2) in itertools.zip_longest(
                getattr(fsdp_model, call)(prefix=prefix, recurse=recurse),
                getattr(model, call)(prefix=prefix, recurse=recurse),
            ):
                self.assertEqual(n1, n2)
                self.assertEqual(p1, p2)
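
# The two versions above target different revisions of the FSDP test
# utilities: the first builds the module through the newer
# ``NestedWrappedModule.init(...)`` classmethod and enters
# ``FSDP.summon_full_params`` as a static method, while this one passes
# ``wrap_fsdp``/``fsdp_init_mode`` to the constructor. The core check is the
# same either way. Below is a minimal sketch of that check with no FSDP or
# distributed setup at all -- two identically initialized plain modules
# should agree entry-for-entry on names and tensors:
import itertools
import torch
from torch.nn import Linear, Module

class Tiny(Module):
    def __init__(self):
        super().__init__()
        self.lin = Linear(2, 2)
        self.register_buffer("buffer", torch.ones(1))

a, b = Tiny(), Tiny()
b.load_state_dict(a.state_dict())  # make the two copies identical
for call in ["named_parameters", "named_buffers"]:
    for (n1, t1), (n2, t2) in itertools.zip_longest(
        getattr(a, call)(prefix="model", recurse=True),
        getattr(b, call)(prefix="model", recurse=True),
    ):
        assert n1 == n2 and torch.equal(t1, t2)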