# Imports this excerpt relies on; MEVO and the constants VOCAB, D_MODEL and TILE
# are assumed to be imported/defined elsewhere in the file.
import torch
from torch import nn
from torch.nn import Linear, Module

from fairscale.nn import FullyShardedDataParallel as FSDP


class Model(Module):
    def __init__(self, with_fsdp=False, inner_flat=False, sharing=None):
        super().__init__()
        self.l0 = Linear(4, 4, bias=True).cuda()
        self.l1 = Linear(4, 4, bias=True).cuda()
        self.l2 = Linear(4, 4, bias=True).cuda()
        self.l3 = Linear(4, 4, bias=True).cuda()

        # Share the weights. The layer must have at least 1 param that's not
        # shared. Therefore, we use bias=True and test sharing either the
        # weight or the bias.
        if sharing == "share_only_weights":
            self.l1.weight = self.l3.weight
        elif sharing == "share_only_bias":
            self.l1.bias = self.l3.bias
        else:
            assert sharing is None or sharing == "share_none"

        if with_fsdp:
            # Shared layers must be un-flattened.
            self.l1 = FSDP(self.l1, flatten_parameters=False)
            self.l2 = FSDP(self.l2, flatten_parameters=inner_flat)
            self.l3 = FSDP(self.l3, flatten_parameters=False)

            if sharing == "share_only_weights":
                self.l3.append_shared_param(self.l1.module.weight)
            if sharing == "share_only_bias":
                self.l3.append_shared_param(self.l1.module.bias)

    def forward(self, x):
        x = self.l0(x)
        x = self.l1(x)
        x = self.l2(x)
        x = self.l3(x)
        return x
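# A minimal usage sketch (not part of the original test): it assumes a process group
# has already been initialized via torch.distributed.init_process_group and shows how
# the shared-weight model above is typically wrapped in an outer FSDP and trained for
# a few steps. The helper name `build_and_train` and the hyper-parameters are hypothetical.
def build_and_train(sharing="share_only_weights", num_steps=3):
    model = Model(with_fsdp=True, inner_flat=True, sharing=sharing)
    # The outer wrapper may flatten its own (non-shared) parameters; the shared
    # weight/bias stays inside the un-flattened inner FSDP instances.
    model = FSDP(model, flatten_parameters=True)
    optim = torch.optim.SGD(model.parameters(), lr=0.1)
    for _ in range(num_steps):
        x = torch.rand(2, 4).cuda()
        loss = model(x).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()
    return model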
class Model(nn.Module):
    def __init__(self, with_fsdp=False, wrap_middle="none"):
        super().__init__()
        self.l0 = nn.Embedding(VOCAB, D_MODEL).cuda().half()
        nn.init.uniform_(self.l0.weight, -1.0e-1, 1.0e-1)
        self.l1 = MEVO(self.l0.weight, tile_factor=TILE, reduction="sum")
        self.middle = nn.Linear(D_MODEL, D_MODEL).cuda().half()
        # LayerNorms are not strictly needed for this test, but they help reduce the
        # loss quickly and improve numerical stability.
        self.ln1 = nn.LayerNorm(D_MODEL).cuda().half()
        self.ln2 = nn.LayerNorm(D_MODEL).cuda().half()

        if with_fsdp:
            # Shared layers must be un-flattened.
            self.l0 = FSDP(self.l0, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
            self.l1 = FSDP(self.l1, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
            self.l1.append_shared_param(self.l0.module.weight)

            # These are for debugging.
            # print(id(self.l0), "is emb")
            # print(id(self.l1), "is out")

            assert wrap_middle in ["none", "flat", "nonflat"]
            if wrap_middle != "none":
                self.middle = FSDP(
                    self.middle,
                    flatten_parameters=(wrap_middle == "flat"),
                    mixed_precision=False,
                    compute_dtype=torch.float16,
                )
                # print(id(self.middle), "is middle")

    def forward(self, x):
        target = x + 1
        x = self.l0(x)
        x = self.ln1(x)
        x = self.middle(x)
        x = self.ln2(x)
        x = self.l1(x, target)
        print("LOSS", x.item())
        assert x.item() not in [float("-inf"), float("inf")]
        return x
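# A minimal usage sketch (not part of the original test): it assumes VOCAB, D_MODEL and
# TILE are defined at the top of the file and that a process group is already initialized.
# The helper name `train_one_step`, the outer-wrapper flags, and the batch/sequence sizes
# are hypothetical.
def train_one_step(wrap_middle="flat"):
    model = Model(with_fsdp=True, wrap_middle=wrap_middle)
    # Outer wrapper; the embedding weight shared with the MEVO output layer stays
    # un-flattened inside the inner FSDP instances.
    model = FSDP(model, flatten_parameters=False, mixed_precision=False, compute_dtype=torch.float16)
    optim = torch.optim.SGD(model.parameters(), lr=1e-3)
    # Token indices; forward() uses `target = x + 1`, so keep indices below VOCAB - 1.
    x = torch.randint(0, VOCAB - 1, (2, 16)).cuda()
    loss = model(x)
    loss.backward()
    optim.step()
    optim.zero_grad()
    return loss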