def forward(self, x: torch.LongTensor, pad_mask: torch.BoolTensor = None, pos: torch.LongTensor = None):
    assert x.shape[-1] == self.seq_len, f"Expected sequence length '{self.seq_len}'! Got {x.shape[-1]}"

    # Default positions are 0..seq_len-1, broadcast to the batch shape of x.
    if pos is None:
        pos = (
            torch.arange(self.seq_len, dtype=torch.long, device=x.device)
            .unsqueeze(0)
            .expand_as(x)
        )

    # Sum token and positional embeddings.
    x = self.input_embedding(x) + self.pos_embedding(pos)

    # Run the stack of transformer blocks, then the final layer norm.
    for block in self.blocks:
        x = block(x, pad_mask=pad_mask)
    x = self.norm(x)

    # Project back to vocabulary logits by reusing the input embedding matrix (weight tying).
    logits = x.matmul(self.input_embedding.weight.t())
    prob = torch.softmax(logits, dim=-1)
    return prob, logits
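
# --- Weight-tying sketch (illustrative only, not part of the model code) ----
# The last step of forward() projects hidden states back to vocabulary logits
# by multiplying with the transpose of the input embedding matrix instead of a
# separate output layer. The standalone sketch below, using assumed toy sizes,
# shows that this is equivalent to a bias-free nn.Linear that shares the
# embedding weights.
import torch
import torch.nn as nn

vocab_size, d_model = 100, 16                    # toy sizes for illustration
embedding = nn.Embedding(vocab_size, d_model)

hidden = torch.randn(2, 8, d_model)              # (batch, seq_len, d_model)
logits = hidden.matmul(embedding.weight.t())     # (batch, seq_len, vocab_size)

tied_head = nn.Linear(d_model, vocab_size, bias=False)
tied_head.weight = embedding.weight              # share the same parameter
assert torch.allclose(logits, tied_head(hidden))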