class Decoder(nn.Module):
    """Base class for transformer decoders

    Attributes:
        langpair: Language pair to translate; needed to look up the tokenizer's padding index.
    """
    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        self.embedding = Embeddings(langpair)
        self.config = Config()
        self.config.add_model(is_base)
        self.num_layers = self.config.model.model_params.num_decoder_layer
        self.decoder_layers = get_clones(DecoderLayer(is_base), self.num_layers)

    def forward(
        self,
        target_tokens: Tensor,
        target_mask: Tensor,
        encoder_out: Tensor,
        encoder_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            target_tokens: target token ids fed to the decoder (batch_size, seq_len)
            target_mask: padding mask of the target tokens
            encoder_out: the last encoder layer's output (batch_size, seq_len, dim_model)
            encoder_mask: boolean Tensor where padding elements are indicated by False (batch_size, seq_len)
        """
        target_emb = self.embedding(target_tokens)
        for i in range(self.num_layers):
            target_emb, target_mask = self.decoder_layers[i](target_emb,
                                                             target_mask,
                                                             encoder_out,
                                                             encoder_mask)
        return target_emb, target_mask
Example #2
class Encoder(nn.Module):
    """Transformer encoder consisting of EncoderLayers

    Attributes:
        langpair: Language pair to translate; needed to look up the tokenizer's padding index.
    """
    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        self.embedding = Embeddings(langpair)
        self.config = Config()
        self.config.add_model(is_base)
        self.num_layers = self.config.model.model_params.num_encoder_layer
        self.encoder_layers = get_clones(EncoderLayer(is_base), self.num_layers)

    def forward(self, source_tokens: Tensor,
                source_mask: Tensor) -> EncoderOut:
        source_emb = self.embedding(source_tokens)
        for i in range(self.num_layers):
            source_emb, source_mask = self.encoder_layers[i](source_emb,
                                                             source_mask)
        return EncoderOut(
            encoder_out=source_emb,
            encoder_mask=source_mask,
            source_tokens=source_tokens,
        )
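
# The stray __init__ below builds the full encoder-decoder model from the classes
# above; its enclosing class statement is missing from this listing, so the name
# "Transformer" used here is an assumption added for readability.
class Transformer(nn.Module):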
    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        configs = Config()
        configs.add_tokenizer(langpair)
        configs.add_model(is_base)
        dim_model: int = configs.model.model_params.dim_model
        vocab_size = configs.tokenizer.vocab_size

        self.encoder = Encoder(langpair, is_base)
        self.decoder = Decoder(langpair, is_base)
        self.linear = nn.Linear(dim_model, vocab_size)
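
    def forward(
        self,
        source_tokens: Tensor,
        source_mask: Tensor,
        target_tokens: Tensor,
        target_mask: Tensor,
    ) -> Tensor:
        # Hedged sketch of the missing forward pass: encode the source, decode
        # against the encoder output, then project onto the vocabulary. Shapes and
        # field names follow the Encoder/Decoder definitions above.
        enc = self.encoder(source_tokens, source_mask)
        target_emb, _ = self.decoder(
            target_tokens, target_mask, enc.encoder_out, enc.encoder_mask
        )
        return self.linear(target_emb)  # (batch_size, seq_len, vocab_size)
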
class DecoderLayer(nn.Module):
    """Decoder layer block"""
    def __init__(self, is_base: bool = True):
        super().__init__()
        self.config = Config()
        self.config.add_model(is_base)

        self.masked_mha = MultiHeadAttention(masked_attention=True)
        self.mha = MultiHeadAttention(masked_attention=False)
        self.ln = LayerNorm(self.config.model.train_hparams.eps)
        self.ffn = FeedForwardNetwork()
        self.residual_dropout = nn.Dropout(
            p=self.config.model.model_params.dropout)

    def attention_mask(self, batch_size: int, seq_len: int) -> Tensor:
        attention_shape = (batch_size, seq_len, seq_len)
        attention_mask = np.triu(np.ones(attention_shape), k=1).astype("uint8")
        attention_mask = torch.from_numpy(attention_mask) == 0
        return attention_mask  # (batch_size, seq_len, seq_len)

    def forward(
        self,
        target_emb: Tensor,
        target_mask: Tensor,
        encoder_out: Tensor,
        encoder_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            target_emb: input to the decoder layer (batch_size, seq_len, dim_model)
            target_mask: padding mask of the target embedding
            encoder_out: the last encoder layer's output (batch_size, seq_len, dim_model)
            encoder_mask: boolean Tensor where padding elements are indicated by False (batch_size, seq_len)
        """
        attention_mask = self.attention_mask(
            target_emb.size(0), target_emb.size(1)
        ).to(target_emb.device)  # keep the causal mask on the same device as the input
        target_emb = target_emb + self.masked_mha(
            query=target_emb,
            key=target_emb,
            value=target_emb,
            attention_mask=attention_mask,
        )
        target_emb = self.ln(target_emb)
        target_emb = target_emb + self.mha(
            query=target_emb, key=encoder_out, value=encoder_out)
        target_emb = self.ln(target_emb)
        target_emb = target_emb + self.ffn(target_emb)
        return target_emb, target_mask
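
# Equivalent pure-torch construction of DecoderLayer.attention_mask above (a hedged,
# standalone sketch rather than project code): True marks positions a query may attend to.
#
#     causal_mask(1, 3)[0]
#     tensor([[ True, False, False],
#             [ True,  True, False],
#             [ True,  True,  True]])
import torch

def causal_mask(batch_size: int, seq_len: int) -> torch.Tensor:
    upper = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()
    return (~upper).unsqueeze(0).expand(batch_size, seq_len, seq_len)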
Example #5
class EncoderLayer(nn.Module):
    """Encoder layer block"""
    def __init__(self, is_base: bool = True):
        super().__init__()
        self.config = Config()
        self.config.add_model(is_base)

        self.mha = MultiHeadAttention(masked_attention=False)
        self.attention_dropout = nn.Dropout(
            p=self.config.model.model_params.dropout)
        self.ln = LayerNorm(self.config.model.train_hparams.eps)
        self.ffn = FeedForwardNetwork()
        self.residual_dropout = nn.Dropout(
            p=self.config.model.model_params.dropout)

    def forward(self, source_emb: Tensor,
                source_mask: Tensor) -> Tuple[Tensor, Tensor]:
        source_emb = source_emb + self.mha(
            query=source_emb, key=source_emb, value=source_emb)
        source_emb = self.attention_dropout(source_emb)
        source_emb = self.ln(source_emb)
        source_emb = source_emb + self.residual_dropout(self.ffn(source_emb))
        source_emb = self.ln(source_emb)
        return source_emb, source_mask
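
# Minimal standalone illustration of the post-LN residual pattern both layer blocks
# follow (sublayer output added to its input, then LayerNorm). dim_model=512 and
# dim_ff=2048 are assumed values for the base model, not read from Config.
import torch
import torch.nn as nn

dim_model, dim_ff = 512, 2048
ln = nn.LayerNorm(dim_model)
ffn = nn.Sequential(nn.Linear(dim_model, dim_ff), nn.ReLU(), nn.Linear(dim_ff, dim_model))
x = torch.randn(2, 16, dim_model)  # (batch_size, seq_len, dim_model)
x = ln(x + ffn(x))                 # residual connection followed by layer normalization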
class WMT14DataLoader(LightningDataModule):
    """Load WMT14 dataset and prepare batches for training and testing transformer.

    Attributes:
        langpair: language pair to translate
    """
    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        self.configs = Config()
        self.configs.add_data(langpair)
        self.configs.add_model(is_base)
        self.langpair = langpair
        self.max_length = self.configs.model.model_params.max_len

    def setup(self, stage: Optional[str] = None) -> None:
        """Assign dataset for use in dataloaders

        Args:
            stage: 'fit' loads the train/val splits, 'test' loads the test split
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.train_dataset = WMT14Dataset(self.langpair,
                                              max_length=self.max_length,
                                              mode="train")
            self.valid_dataset = WMT14Dataset(self.langpair,
                                              max_length=self.max_length,
                                              mode="val")
        # Assign test dataset for use in dataloaders
        if stage == "test" or stage is None:
            self.test_dataset = WMT14Dataset(self.langpair,
                                             max_length=self.max_length,
                                             mode="test")

    def batch_by_tokens(
            self,
            dataset: Dataset,
            max_tokens: Optional[int] = None) -> List[torch.Tensor]:
        """Create mini-batch tensors by number of tokens

        Args:
            dataset: source and target dataset containing padded_token, mask, and length
                     e.g.,
                     {'source': {'padded_token': torch.Tensor, 'mask': torch.Tensor, 'length': torch.Tensor},
                      'target': {'padded_token': torch.Tensor, 'mask': torch.Tensor, 'length': torch.Tensor}}
            max_tokens: max number of tokens per batch

        Returns:
            indices_batches: a list of 1-D index tensors, one tensor of dataset indices per mini-batch
        """
        if max_tokens is None:
            max_tokens = 25000  # tokens allowed per batch for each of source and target

        start_idx = 0
        source_sample_lens, target_sample_lens = [], []
        indices_batches = []
        for end_idx in range(len(dataset)):
            source_sample_lens.append(dataset[end_idx]["source"]["length"])
            target_sample_lens.append(dataset[end_idx]["target"]["length"])
            # when batch is full
            if (sum(source_sample_lens) > max_tokens
                    or sum(target_sample_lens) > max_tokens):
                indices_batch = torch.arange(start_idx, end_idx)
                indices_batches.append(indices_batch)
                start_idx = end_idx
                # the sample at end_idx is excluded here; it opens the next batch
                source_sample_lens = [source_sample_lens[-1]]
                target_sample_lens = [target_sample_lens[-1]]
        # emit the final (possibly partial) batch once the iteration ends
        if start_idx < len(dataset):
            indices_batches.append(torch.arange(start_idx, len(dataset)))
        return indices_batches

    # TODO: batch together by approx. sequence length.
    def train_dataloader(self) -> DataLoader:
        batch_sampler = self.batch_by_tokens(self.train_dataset)
        return DataLoader(
            self.train_dataset,
            batch_sampler=batch_sampler,
            shuffle=False,
            drop_last=False,
            num_workers=self.configs.model.data_params.num_workers,
        )

    def val_dataloader(self) -> DataLoader:  # hook name expected by LightningDataModule
        batch_sampler = self.batch_by_tokens(self.valid_dataset)
        return DataLoader(
            self.valid_dataset,
            batch_sampler=batch_sampler,
            shuffle=False,
            drop_last=False,
            num_workers=self.configs.model.data_params.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        batch_sampler = self.batch_by_tokens(self.test_dataset)
        return DataLoader(
            self.test_dataset,
            batch_sampler=batch_sampler,
            shuffle=False,
            drop_last=False,
            num_workers=self.configs.model.data_params.num_workers,
        )
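
# Standalone sketch (not project code) of the token-count bucketing that
# batch_by_tokens implements: indices accumulate until either the source or the
# target side would exceed max_tokens, and the overflowing sample opens the next batch.
import torch

def bucket_by_tokens(source_lens, target_lens, max_tokens=100):
    batches, start, src_sum, tgt_sum = [], 0, 0, 0
    for end, (s, t) in enumerate(zip(source_lens, target_lens)):
        src_sum += s
        tgt_sum += t
        if src_sum > max_tokens or tgt_sum > max_tokens:
            batches.append(torch.arange(start, end))  # `end` is excluded; it starts the next batch
            start, src_sum, tgt_sum = end, s, t
    if start < len(source_lens):
        batches.append(torch.arange(start, len(source_lens)))  # final, possibly partial batch
    return batches

# bucket_by_tokens([60, 50, 40, 30], [55, 45, 35, 25], max_tokens=100)
# -> [tensor([0]), tensor([1, 2]), tensor([3])]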