def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
    # Length of the memory
    m_len = len(mem[0]) if mem else 0
    # Create a subsequent mask for tokens
    if self.mask_x is None or self.mask_x.shape[0] < len(x):
        from labml_nn.transformers.utils import subsequent_mask
        self.mask_x = subsequent_mask(len(x)).to(x.device)
    # Create an all ones (full visibility) mask for memory
    if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
        self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

    # Concatenate the masks if there is memory
    if m_len:
        mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
    # Use the subsequent mask otherwise
    else:
        mask = self.mask_x[:len(x), :len(x)]

    # Token embeddings
    x = self.src_embed(x)
    # Run it through the transformer
    res, mem = self.transformer(x, mem, mask)
    # Generate logits of the next token
    res = self.generator(res)
    #
    return res, mem
def forward(self, src: torch.Tensor):
    # Create subsequent mask, so that the transformer can only pay attention to past tokens.
    if self.src_mask is None or self.src_mask.size(0) != len(src):
        self.src_mask = subsequent_mask(len(src)).to(src.device)
    # Embed the tokens (`src`) and run it through the transformer
    res = self.encoder(self.src_embed(src), self.src_mask)
    # Generate logits of the next token
    return self.generator(res)
def forward(self, x: torch.Tensor):
    # Create a mask if we haven't created one yet or if its size has changed
    if self.mask is None or self.mask.size(0) != len(x):
        # [Subsequent mask](../utils.html), will mask out tokens from seeing future tokens
        self.mask = subsequent_mask(len(x)).to(x.device)
    #
    return self.mask
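# All of the snippets here build their causal mask with `subsequent_mask` from
# `labml_nn.transformers.utils`. The following is a minimal sketch of such a helper,
# not the library's exact implementation: it assumes the mask is a lower-triangular
# boolean tensor of shape `[seq_len, seq_len, 1]`, which is consistent with how the
# code above builds the memory mask with `new_ones(len(x), m_len, 1)` and slices it
# as `[:len(x), :length, :]`.
def subsequent_mask_sketch(seq_len: int) -> torch.Tensor:
    # Position `i` may attend to positions `j <= i`; future positions are masked out
    mask = torch.tril(torch.ones(seq_len, seq_len)).to(torch.bool)
    # Add a trailing singleton dimension so the mask can broadcast over the batch/heads
    return mask.unsqueeze(-1)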
def __call__(self, src: torch.Tensor, _: Any = None):
    # Create the subsequent (causal) mask if it hasn't been created or its size has changed
    if self.src_mask is None or self.src_mask.size(0) != len(src):
        self.src_mask = subsequent_mask(len(src)).to(src.device)
    # Token embeddings
    src = self.src_embed(src)
    # with monit.section("transformer"):
    # Run through the transformer encoder
    output = self.encoder(src, self.src_mask)
    # Generate logits of the next token
    output = self.fc(output)
    return output, None
def forward(self, x: torch.Tensor):
    # Initialize the subsequent mask
    if self.mask is None or self.mask.size(0) != len(x):
        from labml_nn.transformers.utils import subsequent_mask
        self.mask = subsequent_mask(len(x)).to(x.device)
    # Token embeddings
    x = self.src_embed(x)
    # Run it through the transformer
    res, counts, route_prob, n_dropped = self.transformer(x, self.mask)
    # Generate logits of the next token
    res = self.generator(res)
    #
    return res, counts, route_prob, n_dropped
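# The routing statistics returned above (`counts`, `route_prob`, `n_dropped`) are
# typically used to compute a load-balancing auxiliary loss. The sketch below follows
# the formula from the Switch Transformer paper, N * sum_i(f_i * P_i), where `f_i` is
# the fraction of tokens routed to expert `i` and `P_i` the mean routing probability
# for expert `i`. The names, shapes and reductions here are assumptions for
# illustration, not necessarily what this repository's training loop does.
def load_balancing_loss_sketch(counts: torch.Tensor, route_prob: torch.Tensor) -> torch.Tensor:
    # `counts[i]` is the number of tokens routed to expert `i`,
    # `route_prob[i]` is the sum of routing probabilities assigned to expert `i`
    n_experts = counts.shape[-1]
    total = counts.sum(dim=-1, keepdim=True)
    # Fraction of tokens sent to each expert and mean routing probability per expert
    route_frac = counts / total
    prob_mean = route_prob / total
    # The loss is minimized when both distributions are uniform (1 / n_experts each)
    return n_experts * (route_frac * prob_mean).sum(dim=-1).mean()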
def __call__(self, x: torch.Tensor):
    # Create subsequent mask if the mask is not initialized
    # or if the size of the mask is different
    if self.mask is None or self.mask.size(0) != len(x):
        # Subsequent mask, will mask out tokens from seeing future tokens
        self.mask = subsequent_mask(len(x)).to(x.device)
    # Get the token embeddings with positional encodings
    x = self.src_embed(x)
    # Transformer encoder
    x = self.encoder(x, self.mask)
    # Get logits
    x = self.generator(x)
    # Return results
    # (second value is for state, since our trainer is used with RNNs also)
    return x, None
def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
    # Initialize the subsequent mask over the full length (memory + current tokens)
    length = len(x)
    if mem:
        length += len(mem[0])
    if self.mask is None or self.mask.size(0) != length:
        from labml_nn.transformers.utils import subsequent_mask
        self.mask = subsequent_mask(length).to(x.device)
    # Token embeddings
    x = self.src_embed(x)
    # Run it through the transformer.
    # Take the last `len(x)` rows of the mask, so that each current token can attend
    # to all of the memory and to the current tokens that precede it.
    res, mem = self.transformer(x, mem, self.mask[-len(x):, :length, :])
    # Generate logits of the next token
    res = self.generator(res)
    #
    return res, mem
def forward(self, x: torch.Tensor): """ :param x: are the embeddings of shape `[seq_len, batch_size, d_model]` """ # Create causal mask if self.mask is None or self.mask.size(0) != len(x): # Subsequent mask, will mask out tokens from seeing future tokens self.mask = subsequent_mask(len(x)).to(x.device) # Run through self attention, i.e. keys and values are from self x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask)) # Pass through the feed-forward network x = self.feed_forward_norm(x, self.feed_forward(x)) # return x
def forward(self, x: torch.Tensor): """ :param x: are the input tokens of shape `[seq_len, batch_size]` """ # Create auto-regressive mask if self.mask is None or self.mask.size(0) != len(x): # Subsequent mask, will mask out tokens from seeing future tokens self.mask = subsequent_mask(len(x)).to(x.device) # Get the token embeddings x = self.emb(x) # Transformer encoder for layer in self.transformer_layers: x = layer(x=x, mask=self.mask) # Get logits x = self.readout(x) # Return results return x, None
def __call__(self, x: torch.Tensor, mem: Optional[List[torch.Tensor]]):
    # Length of the memory
    m_len = len(mem[0]) if mem else 0
    # Subsequent mask for the current tokens
    if self.mask_x is None or self.mask_x.shape[0] < len(x):
        from labml_nn.transformers.utils import subsequent_mask
        self.mask_x = subsequent_mask(len(x)).to(x.device)
    # All-ones (full visibility) mask for the memory
    if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
        self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

    # Concatenate the memory and token masks if there is memory
    if m_len:
        mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
    else:
        mask = self.mask_x[:len(x), :len(x)]

    # Embed the tokens, run them through the transformer along with the memories,
    # and generate logits of the next token
    x = self.src_embed(x)
    res, mem = self.transformer(x, mem, mask)
    res = self.generator(res)
    return res, mem
def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
    # Initialize the subsequent mask
    m_len = len(mem[0]) if mem else 0
    if self.mask_x is None or self.mask_x.shape[0] < len(x):
        from labml_nn.transformers.utils import subsequent_mask
        self.mask_x = subsequent_mask(len(x)).to(x.device)
    if self.mask_mem is None or self.mask_mem.shape[1] < m_len or self.mask_mem.shape[0] < len(x):
        self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

    if m_len:
        mask = torch.cat((self.mask_mem[:len(x), :m_len], self.mask_x[:len(x), :len(x)]), dim=1)
    else:
        mask = self.mask_x[:len(x), :len(x)]

    # Token embeddings
    x = self.src_embed(x)
    # Run it through the transformer
    res, mem = self.transformer(x, mem, mask)
    # Generate logits of the next token
    res = self.generator(res)
    #
    return res, mem
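# A minimal sketch of how a memory-carrying forward like the ones above is usually
# driven: a long token sequence is split into chunks, and the memories returned for
# one chunk are passed back in with the next chunk. `model`, `seq` and `chunk_len`
# are illustrative names, not the repository's trainer.
def run_with_memory_sketch(model, seq: torch.Tensor, chunk_len: int = 128) -> torch.Tensor:
    mem = []  # start with no memory
    logits = []
    for start in range(0, len(seq), chunk_len):
        chunk = seq[start:start + chunk_len]
        # The model attends to the current chunk and to the carried-over memories
        res, mem = model(chunk, mem)
        logits.append(res)
    # Concatenate the per-chunk logits along the sequence dimension
    return torch.cat(logits, dim=0)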