# Imports assumed from the surrounding test module (FSDP here is fairscale's
# FullyShardedDataParallel).
from torch.nn import Linear, Module

from fairscale.nn.data_parallel import FullyShardedDataParallel as FSDP


class Model(Module):
    def __init__(self, with_fsdp=False, inner_flat=False, sharing=None):
        super().__init__()
        self.l0 = Linear(4, 4, bias=True).cuda()
        self.l1 = Linear(4, 4, bias=True).cuda()
        self.l2 = Linear(4, 4, bias=True).cuda()
        self.l3 = Linear(4, 4, bias=True).cuda()

        # Share the weights. The wrapped layer must have at least one param
        # that is *not* shared. Therefore, we use bias=True and test sharing
        # either the weight or the bias.
        if sharing == "share_only_weights":
            self.l1.weight = self.l3.weight
        elif sharing == "share_only_bias":
            self.l1.bias = self.l3.bias
        else:
            assert sharing is None or sharing == "share_none"

        if with_fsdp:
            # Shared layers must not be flattened.
            self.l1 = FSDP(self.l1, flatten_parameters=False)
            self.l2 = FSDP(self.l2, flatten_parameters=inner_flat)
            self.l3 = FSDP(self.l3, flatten_parameters=False)

            if sharing == "share_only_weights":
                self.l3.append_shared_param(self.l1.module.weight)
            elif sharing == "share_only_bias":
                self.l3.append_shared_param(self.l1.module.bias)

    def forward(self, x):
        x = self.l0(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        return x
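

# Usage sketch (not part of the original test): verify the tied parameters on a
# single GPU without FSDP. The FSDP path would additionally require an
# initialized torch.distributed process group before wrapping.
def _check_shared_weights_sketch(model_cls=Model):
    # Bind the class above at definition time, since another Model is defined
    # further below in this file.
    import torch

    model = model_cls(with_fsdp=False, sharing="share_only_weights")
    # Tied params must be the very same Parameter object, not equal copies.
    assert model.l1.weight is model.l3.weight
    assert model.l1.bias is not model.l3.bias

    x = torch.rand(2, 4, device="cuda")
    model(x).sum().backward()
    # Gradients of a tied parameter accumulate into a single .grad tensor.
    assert model.l1.weight.grad is model.l3.weight.grad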


# A second Model variant: a MEVO output layer that shares the embedding weight.
# Assumed to come from a separate test module that imports torch, torch.nn as
# nn, FSDP as above, MEVO (fairscale's memory-efficient vocab output), and
# defines the constants VOCAB, D_MODEL, and TILE.
class Model(nn.Module):
    def __init__(self, with_fsdp=False, wrap_middle="none"):
        super().__init__()
        self.l0 = nn.Embedding(VOCAB, D_MODEL).cuda().half()
        nn.init.uniform_(self.l0.weight, -1.0e-1, 1.0e-1)
        self.l1 = MEVO(self.l0.weight, tile_factor=TILE, reduction="sum")
        self.middle = nn.Linear(D_MODEL, D_MODEL).cuda().half()
        # LNs are not strictly needed for this test, but they help reduce the loss
        # quickly and improve numerical stability.
        self.ln1 = nn.LayerNorm(D_MODEL).cuda().half()
        self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()

        if with_fsdp:
            # Shared layers must not be flattened.
            self.l0 = FSDP(self.l0,
                           flatten_parameters=False,
                           mixed_precision=False,
                           compute_dtype=torch.float16)
            self.l1 = FSDP(self.l1,
                           flatten_parameters=False,
                           mixed_precision=False,
                           compute_dtype=torch.float16)
            self.l1.append_shared_param(self.l0.module.weight)
            # These are for debugging.
            # print(id(self.l0), "is emb")
            # print(id(self.l1), "is out")
            assert wrap_middle in ["none", "flat", "nonflat"]
            if wrap_middle != "none":
                self.middle = FSDP(
                    self.middle,
                    flatten_parameters=(wrap_middle == "flat"),
                    mixed_precision=False,
                    compute_dtype=torch.float16,
                )
                # print(id(self.middle), "is middle")

    def forward(self, x):
        target = x + 1
        x = self.l0(x)
        x = self.ln1(x)
        x = self.middle(x)
        x = self.ln2(x)
        x = self.l1(x, target)
        print("LOSS", x.item())
        assert x.item() not in [float("-inf"), float("inf")]
        return x
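

# Usage sketch (not part of the original test): one fp16 training step without
# FSDP wrapping. VOCAB, D_MODEL and TILE are the module-level constants of the
# real test; any values consistent with the Model above would do here.
def _train_step_sketch():
    model = Model(with_fsdp=False)
    opt = torch.optim.SGD(model.parameters(), lr=1e-3)

    # Token ids in [0, VOCAB - 2] so that target = x + 1 stays within range.
    x = torch.randint(0, VOCAB - 1, (2, 16), device="cuda")
    loss = model(x)

    opt.zero_grad()
    loss.backward()
    opt.step()
    return loss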