class Model(Module):
    """Two-``Linear`` model whose ``inner`` layer is applied twice per forward pass.

    Args:
        wrap_fsdp: when truthy, the ``inner`` ``Linear`` is wrapped in ``FSDP``
            before ``outer`` is created.
        register_buffer: when truthy, a buffer named ``"buffer"`` holding
            ``torch.randn(BUFFER_SHAPE)`` is registered on ``outer``.
    """

    def __init__(self, wrap_fsdp, register_buffer=False):
        super().__init__()
        # Construction order is kept fixed (inner, outer, then the buffer):
        # it determines submodule registration order and the order in which
        # random values are drawn.
        inner_layer = Linear(*INNER_SHAPE)
        self.inner = FSDP(inner_layer) if wrap_fsdp else inner_layer
        self.outer = Linear(*OUTER_SHAPE)
        if not register_buffer:
            return
        self.outer.register_buffer("buffer", torch.randn(BUFFER_SHAPE))

    def forward(self, x):
        """Apply ``inner`` to ``x`` twice, sum both results, and pass the sum to ``outer``."""
        # The same (possibly FSDP-wrapped) submodule is deliberately invoked
        # twice; the first call happens before the second.
        first_pass = self.inner(x)
        second_pass = self.inner(x)
        summed = first_pass + second_pass
        return self.outer(summed)