Example #1
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
        # Length of the memory
        m_len = len(mem[0]) if mem else 0
        # Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all ones (full visibility) mask for memory
        if (self.mask_mem is None or self.mask_mem.shape[1] < m_len
                or self.mask_mem.shape[0] < len(x)):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len],
                              self.mask_x[:len(x), :len(x)]),
                             dim=1)
        # Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
        # Generate logits of the next token
        res = self.generator(res)
        # Return the logits and the new memory
        return res, mem
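
To make the shapes concrete: with a segment of q_len tokens and m_len memory slots, the concatenated mask has shape [q_len, m_len + q_len, 1], where memory positions are always visible and positions within the segment stay causal. A minimal standalone sketch of that construction (torch.tril stands in for the library's subsequent_mask; the variable names are illustrative):

import torch

q_len, m_len = 4, 2
# Causal (lower-triangular) mask over the current segment: [q_len, q_len, 1]
mask_x = torch.tril(torch.ones(q_len, q_len, dtype=torch.bool)).unsqueeze(-1)
# Full-visibility mask over the memory: [q_len, m_len, 1]
mask_mem = mask_x.new_ones(q_len, m_len, 1)
# Combined mask over the [memory, segment] keys: [q_len, m_len + q_len, 1]
mask = torch.cat((mask_mem, mask_x), dim=1)
print(mask.shape)  # torch.Size([4, 6, 1])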
Example #2
 def forward(self, src: torch.Tensor):
     # Create subsequent mask, so that the transformer can only pay attention to past tokens.
     if self.src_mask is None or self.src_mask.size(0) != len(src):
         self.src_mask = subsequent_mask(len(src)).to(src.device)
     # Embed the tokens (`src`) and run it through the transformer
     res = self.encoder(self.src_embed(src), self.src_mask)
     # Generate logits of the next token
     return self.generator(res)
Example #3
    def forward(self, x: torch.Tensor):
        # Create the mask if it hasn't been created or the size has changed
        if self.mask is None or self.mask.size(0) != len(x):
            # The [subsequent mask](../utils.html) prevents tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)

        # Return the mask
        return self.mask
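
For reference, the subsequent_mask helper used throughout these examples returns a lower-triangular mask with a trailing singleton dimension, so position i can only attend to positions <= i. A minimal sketch of that behavior (not necessarily the exact labml_nn.transformers.utils implementation):

import torch

def subsequent_mask_sketch(seq_len: int) -> torch.Tensor:
    # True where a query position may attend to a key position (key <= query)
    mask = torch.tril(torch.ones(seq_len, seq_len, dtype=torch.bool))
    # Trailing singleton dimension so the mask broadcasts over the batch axis
    return mask.unsqueeze(-1)

# subsequent_mask_sketch(3)[..., 0]:
# tensor([[ True, False, False],
#         [ True,  True, False],
#         [ True,  True,  True]])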
Example #4
    def __call__(self, src: torch.Tensor, _: Any = None):
        # Create the subsequent mask if it hasn't been created or the size has changed
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            self.src_mask = subsequent_mask(len(src)).to(src.device)

        # Embed the tokens
        src = self.src_embed(src)
        # Run through the transformer encoder
        # with monit.section("transformer"):
        output = self.encoder(src, self.src_mask)
        # Final fully-connected layer gives the logits
        output = self.fc(output)
        # Second return value is the state (kept for compatibility with RNN trainers)
        return output, None
Example #5
 def forward(self, x: torch.Tensor):
     # Initialize the subsequent mask
     if self.mask is None or self.mask.size(0) != len(x):
         from labml_nn.transformers.utils import subsequent_mask
         self.mask = subsequent_mask(len(x)).to(x.device)
     # Token embeddings
     x = self.src_embed(x)
     # Run it through the transformer
     res, counts, route_prob, n_dropped = self.transformer(x, self.mask)
     # Generate logits of the next token
     res = self.generator(res)
     # Return the logits along with the routing statistics
     return res, counts, route_prob, n_dropped
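
The extra outputs in Example #5 (counts, route_prob, n_dropped) are routing statistics from the switching (mixture-of-experts) layers. A hedged sketch of the Switch Transformer load-balancing loss they typically feed, assuming counts holds the number of tokens routed to each expert and route_prob the summed router probability per expert (the helper name and shapes are assumptions, not from the source):

import torch

def load_balancing_loss(counts: torch.Tensor, route_prob: torch.Tensor) -> torch.Tensor:
    # Assumed shapes: [..., n_experts] (e.g. one row per switching layer)
    n_experts = counts.shape[-1]
    total = counts.sum(-1, keepdim=True)
    # f_i: fraction of tokens dispatched to expert i
    route_frac = counts / total
    # P_i: mean router probability assigned to expert i
    mean_prob = route_prob / total
    # Switch Transformer auxiliary loss: n_experts * sum_i f_i * P_i, averaged over layers
    return n_experts * (route_frac * mean_prob).sum(-1).mean()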
Example #6
    def __call__(self, x: torch.Tensor):
        # Create the subsequent mask if it hasn't been initialized
        # or if its size has changed
        if self.mask is None or self.mask.size(0) != len(x):
            # The subsequent mask prevents tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)
        # Get the token embeddings with positional encodings
        x = self.src_embed(x)
        # Transformer encoder
        x = self.encoder(x, self.mask)
        # Get logits
        x = self.generator(x)

        # Return results
        # (second value is for state, since our trainer is used with RNNs also)
        return x, None
Example #7
 def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
     # Initialize the subsequent mask
     length = len(x)
     if mem:
         length += len(mem[0])
     if self.mask is None or self.mask.size(0) != length:
         from labml_nn.transformers.utils import subsequent_mask
         self.mask = subsequent_mask(length).to(x.device)
     # Token embeddings
     x = self.src_embed(x)
     # Run it through the transformer; the last `len(x)` rows of the mask correspond
     # to the current segment's queries, since the memory comes first among the keys
     res, mem = self.transformer(x, mem, self.mask[-len(x):, :length, :])
     # Generate logits of the next token
     res = self.generator(res)
     # Return the logits and the new memory
     return res, mem
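
A hedged usage sketch of the memory interface in the examples that take mem: the memory starts empty and the memories returned by one call are fed into the next call, so later segments can attend to earlier ones (run_segments, model, and segments are illustrative names, not from the source):

import torch

def run_segments(model, segments):
    # `segments` is an iterable of [seq_len, batch_size] token-id tensors;
    # `model` is assumed to be one of the memory-taking modules above
    mem = []  # no memory before the first segment
    outputs = []
    for segment in segments:
        logits, mem = model(segment, mem)  # reuse the returned memories
        outputs.append(logits)
    # Concatenate the per-segment logits along the sequence dimension
    return torch.cat(outputs, dim=0)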
Example #8
    def forward(self, x: torch.Tensor):
        """
        :param x: are the embeddings of shape `[seq_len, batch_size, d_model]`
        """
        # Create causal mask
        if self.mask is None or self.mask.size(0) != len(x):
            # The subsequent mask prevents tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)

        # Run through self-attention, i.e. the keys and values come from x itself
        x = self.self_attn_norm(x, self.self_attn(query=x, key=x, value=x, mask=self.mask))
        # Pass through the feed-forward network
        x = self.feed_forward_norm(x, self.feed_forward(x))

        # Return the layer output
        return x
Example #9
    def forward(self, x: torch.Tensor):
        """
        :param x: are the input tokens of shape `[seq_len, batch_size]`
        """
        # Create auto-regressive mask
        if self.mask is None or self.mask.size(0) != len(x):
            # The subsequent mask prevents tokens from attending to future tokens
            self.mask = subsequent_mask(len(x)).to(x.device)

        # Get the token embeddings
        x = self.emb(x)
        # Transformer encoder
        for layer in self.transformer_layers:
            x = layer(x=x, mask=self.mask)
        # Get logits
        x = self.readout(x)

        # Return results
        return x, None
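
A hedged sketch of how these autoregressive models are trained on next-token prediction: logits at position i are compared with the token at position i + 1 (shapes follow the docstring above; next_token_loss and its arguments are illustrative, not from the source):

import torch
import torch.nn.functional as F

def next_token_loss(model, tokens: torch.Tensor) -> torch.Tensor:
    # `tokens` has shape [seq_len, batch_size]; the model returns logits of
    # shape [seq_len, batch_size, n_vocab] plus a state, which is ignored here
    logits, _ = model(tokens[:-1])
    # Targets are the inputs shifted by one position
    targets = tokens[1:]
    return F.cross_entropy(logits.reshape(-1, logits.shape[-1]), targets.reshape(-1))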
Example #10
    def __call__(self, x: torch.Tensor, mem: Optional[List[torch.Tensor]]):
        # Length of the memory
        m_len = len(mem[0]) if mem else 0
        # Create a subsequent mask for tokens
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        # Create an all ones (full visibility) mask for memory
        if (self.mask_mem is None or self.mask_mem.shape[1] < m_len
                or self.mask_mem.shape[0] < len(x)):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        # Concatenate the masks if there is memory
        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len],
                              self.mask_x[:len(x), :len(x)]),
                             dim=1)
        # Use the subsequent mask otherwise
        else:
            mask = self.mask_x[:len(x), :len(x)]

        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
        # Generate logits of the next token
        res = self.generator(res)
        # Return the logits and the new memory
        return res, mem
Example #11
    def forward(self, x: torch.Tensor, mem: List[torch.Tensor]):
        # Initialize the subsequent mask
        m_len = len(mem[0]) if mem else 0
        if self.mask_x is None or self.mask_x.shape[0] < len(x):
            from labml_nn.transformers.utils import subsequent_mask
            self.mask_x = subsequent_mask(len(x)).to(x.device)
        if (self.mask_mem is None or self.mask_mem.shape[1] < m_len
                or self.mask_mem.shape[0] < len(x)):
            self.mask_mem = self.mask_x.new_ones(len(x), m_len, 1)

        if m_len:
            mask = torch.cat((self.mask_mem[:len(x), :m_len],
                              self.mask_x[:len(x), :len(x)]),
                             dim=1)
        else:
            mask = self.mask_x[:len(x), :len(x)]
        # Token embeddings
        x = self.src_embed(x)
        # Run it through the transformer
        res, mem = self.transformer(x, mem, mask)
        # Generate logits of the next token
        res = self.generator(res)
        # Return the logits and the new memory
        return res, mem