class Decoder(nn.Module):
    """Base class for transformer decoders

    Attributes:
        langpair: Language pair to translate. Necessary due to the padding idx of tokenizer.
    """

    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        self.embedding = Embeddings(langpair)
        self.config = Config()
        self.config.add_model(is_base)
        self.num_layers = self.config.model.model_params.num_decoder_layer
        self.decoder_layers = get_clones(DecoderLayer(), self.num_layers)

    def forward(
        self,
        target_tokens: Tensor,
        target_mask: Tensor,
        encoder_out: Tensor,
        encoder_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            target_tokens: target token ids fed to the decoder (batch_size, seq_len)
            target_mask: padding mask of the target tokens
            encoder_out: the last encoder layer's output (batch_size, seq_len, dim_model)
            encoder_mask: boolean Tensor where padding elements are indicated by False (batch_size, seq_len)
        """
        target_emb = self.embedding(target_tokens)
        for i in range(self.num_layers):
            target_emb, target_mask = self.decoder_layers[i](
                target_emb, target_mask, encoder_out, encoder_mask
            )
        return target_emb, target_mask
class Encoder(nn.Module):
    """Transformer encoder consisting of EncoderLayers

    Attributes:
        langpair: Language pair to translate. Necessary due to the padding idx of tokenizer.
    """

    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        self.embedding = Embeddings(langpair)
        self.config = Config()
        self.config.add_model(is_base)
        self.num_layers = self.config.model.model_params.num_encoder_layer
        self.encoder_layers = get_clones(EncoderLayer(), self.num_layers)

    def forward(self, source_tokens: Tensor, source_mask: Tensor) -> NamedTuple:
        source_emb = self.embedding(source_tokens)
        for i in range(self.num_layers):
            source_emb, source_mask = self.encoder_layers[i](source_emb, source_mask)
        return EncoderOut(
            encoder_out=source_emb,
            encoder_mask=source_mask,
            source_tokens=source_tokens,
        )
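# Usage sketch for Encoder (hedged: the "de-en" langpair string, the vocabulary range,
# padding id 0, and the tensor sizes are illustrative, and running this requires the
# repo's Config/tokenizer files so that Embeddings and the model params can resolve):
#
#   encoder = Encoder(langpair="de-en")
#   source_tokens = torch.randint(1, 100, (2, 10))  # (batch_size=2, seq_len=10) token ids
#   source_mask = source_tokens != 0                # True where the position is not padding
#   out = encoder(source_tokens, source_mask)
#   out.encoder_out   # contextualized representations, (2, 10, dim_model)
#   out.encoder_mask  # source mask passed through the layers, (2, 10)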
def __init__(self, langpair: str, is_base: bool = True) -> None:
    super().__init__()
    configs = Config()
    configs.add_tokenizer(langpair)
    configs.add_model(is_base)
    dim_model: int = configs.model.model_params.dim_model
    vocab_size = configs.tokenizer.vocab_size
    self.encoder = Encoder(langpair)
    self.decoder = Decoder(langpair)
    self.linear = nn.Linear(dim_model, vocab_size)
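# The fragment above only shows __init__ of what appears to be the top-level seq2seq
# module; a hedged sketch of how its forward pass would likely compose the pieces
# (the method below, its argument names, and returning raw logits over the vocabulary
# are assumptions, not taken from the source):
#
#   def forward(self, source_tokens, source_mask, target_tokens, target_mask):
#       encoder_outputs = self.encoder(source_tokens, source_mask)
#       target_emb, _ = self.decoder(
#           target_tokens,
#           target_mask,
#           encoder_outputs.encoder_out,
#           encoder_outputs.encoder_mask,
#       )
#       return self.linear(target_emb)  # (batch_size, seq_len, vocab_size) logits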
class DecoderLayer(nn.Module):
    """Decoder layer block"""

    def __init__(self, is_base: bool = True):
        super().__init__()
        self.config = Config()
        self.config.add_model(is_base)
        self.masked_mha = MultiHeadAttention(masked_attention=True)
        self.mha = MultiHeadAttention(masked_attention=False)
        self.ln = LayerNorm(self.config.model.train_hparams.eps)
        self.ffn = FeedForwardNetwork()
        self.residual_dropout = nn.Dropout(p=self.config.model.model_params.dropout)

    def attention_mask(self, batch_size: int, seq_len: int) -> Tensor:
        """Build a causal (look-ahead) mask: True marks positions a query may attend to."""
        attention_shape = (batch_size, seq_len, seq_len)
        attention_mask = np.triu(np.ones(attention_shape), k=1).astype("uint8")
        attention_mask = torch.from_numpy(attention_mask) == 0
        return attention_mask  # (batch_size, seq_len, seq_len)

    def forward(
        self,
        target_emb: Tensor,
        target_mask: Tensor,
        encoder_out: Tensor,
        encoder_mask: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Tensor]:
        """
        Args:
            target_emb: input to the decoder layer (batch_size, seq_len, dim_model)
            target_mask: padding mask of the target embedding
            encoder_out: the last encoder layer's output (batch_size, seq_len, dim_model)
            encoder_mask: boolean Tensor where padding elements are indicated by False (batch_size, seq_len)
        """
        attention_mask = self.attention_mask(target_emb.size(0), target_emb.size(1))
        # masked self-attention sub-layer with residual connection
        target_emb = target_emb + self.masked_mha(
            query=target_emb,
            key=target_emb,
            value=target_emb,
            attention_mask=attention_mask,
        )
        target_emb = self.ln(target_emb)
        # encoder-decoder attention sub-layer with residual connection
        target_emb = target_emb + self.mha(query=target_emb, key=encoder_out, value=encoder_out)
        target_emb = self.ln(target_emb)
        # position-wise feed-forward sub-layer with residual connection
        target_emb = target_emb + self.ffn(target_emb)
        return target_emb, target_mask
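# Self-contained check of the causal mask built by DecoderLayer.attention_mask
# (the batch size and sequence length below are illustrative):
import numpy as np
import torch

batch_size, seq_len = 1, 4
upper = np.triu(np.ones((batch_size, seq_len, seq_len)), k=1).astype("uint8")
causal_mask = torch.from_numpy(upper) == 0
print(causal_mask[0].int())
# tensor([[1, 0, 0, 0],
#         [1, 1, 0, 0],
#         [1, 1, 1, 0],
#         [1, 1, 1, 1]], dtype=torch.int32)
# True (1) marks positions a query may attend to: itself and earlier tokens only,
# so no information flows from future positions during decoding.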
class EncoderLayer(nn.Module):
    """Encoder layer block"""

    def __init__(self, is_base: bool = True):
        super().__init__()
        self.config = Config()
        self.config.add_model(is_base)
        self.mha = MultiHeadAttention(masked_attention=False)
        self.attention_dropout = nn.Dropout(p=self.config.model.model_params.dropout)
        self.ln = LayerNorm(self.config.model.train_hparams.eps)
        self.ffn = FeedForwardNetwork()
        self.residual_dropout = nn.Dropout(p=self.config.model.model_params.dropout)

    def forward(self, source_emb: Tensor, source_mask: Tensor) -> Tuple[Tensor, Tensor]:
        source_emb = source_emb + self.mha(query=source_emb, key=source_emb, value=source_emb)
        source_emb = self.attention_dropout(source_emb)
        source_emb = self.ln(source_emb)
        source_emb = source_emb + self.residual_dropout(self.ffn(source_emb))
        source_emb = self.ln(source_emb)
        return source_emb, source_mask
class WMT14DataLoader(LightningDataModule):
    """Load WMT14 dataset and prepare batches for training and testing transformer.

    Attributes:
        langpair: language pair to translate
    """

    def __init__(self, langpair: str, is_base: bool = True) -> None:
        super().__init__()
        self.configs = Config()
        self.configs.add_data(langpair)
        self.configs.add_model(is_base)
        self.langpair = langpair
        self.max_length = self.configs.model.model_params.max_len

    def setup(self, stage: Optional[str] = None) -> None:
        """Assign datasets for use in dataloaders

        Args:
            stage: decide to load train/val or test
        """
        # Assign train/val datasets for use in dataloaders
        if stage == "fit" or stage is None:
            self.train_dataset = WMT14Dataset(self.langpair, max_length=self.max_length, mode="train")
            self.valid_dataset = WMT14Dataset(self.langpair, max_length=self.max_length, mode="val")
        # Assign test dataset for use in dataloaders
        if stage == "test" or stage is None:
            self.test_dataset = WMT14Dataset(self.langpair, max_length=self.max_length, mode="test")

    def batch_by_tokens(self, dataset: Dataset, max_tokens: Optional[int] = None) -> List[torch.Tensor]:
        """Create mini-batches by number of tokens rather than number of samples

        Args:
            dataset: source and target dataset containing padded_token, mask, and length
                e.g., {'source': {'padded_token': torch.Tensor, 'mask': torch.Tensor, 'length': torch.Tensor},
                       'target': {'padded_token': torch.Tensor, 'mask': torch.Tensor, 'length': torch.Tensor}}
            max_tokens: max number of tokens per batch

        Returns:
            indices_batches: list of index tensors, one tensor of sample indices per mini-batch
        """
        max_tokens = 25000 if max_tokens is None else max_tokens
        start_idx = 0
        source_sample_lens, target_sample_lens = [], []
        indices_batches = []
        for end_idx in range(len(dataset)):
            source_sample_lens.append(dataset[end_idx]["source"]["length"])
            target_sample_lens.append(dataset[end_idx]["target"]["length"])
            # when batch is full
            if sum(source_sample_lens) > max_tokens or sum(target_sample_lens) > max_tokens:
                indices_batch = torch.arange(start_idx, end_idx)  # end_idx is not included
                indices_batches.append(indices_batch)
                start_idx = end_idx
                source_sample_lens = [source_sample_lens[-1]]
                target_sample_lens = [target_sample_lens[-1]]
        # when iteration ends, flush the remaining samples into the last batch
        if start_idx < len(dataset):
            indices_batches.append(torch.arange(start_idx, len(dataset)))
        return indices_batches

    # TODO: batch together by approx. sequence length.
    def train_dataloader(self) -> DataLoader:
        batch_sampler = self.batch_by_tokens(self.train_dataset)
        return DataLoader(
            self.train_dataset,
            batch_sampler=batch_sampler,
            num_workers=self.configs.model.data_params.num_workers,
        )

    def val_dataloader(self) -> DataLoader:
        batch_sampler = self.batch_by_tokens(self.valid_dataset)
        return DataLoader(
            self.valid_dataset,
            batch_sampler=batch_sampler,
            num_workers=self.configs.model.data_params.num_workers,
        )

    def test_dataloader(self) -> DataLoader:
        batch_sampler = self.batch_by_tokens(self.test_dataset)
        return DataLoader(
            self.test_dataset,
            batch_sampler=batch_sampler,
            num_workers=self.configs.model.data_params.num_workers,
        )
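# Toy, self-contained illustration of the token-budget batching idea behind
# batch_by_tokens (the sample lengths and the budget of 10 tokens are made up):
import torch

def batch_indices_by_tokens(lengths, max_tokens):
    batches, start_idx, running = [], 0, 0
    for end_idx, length in enumerate(lengths):
        running += length
        if running > max_tokens:
            batches.append(torch.arange(start_idx, end_idx))  # end_idx starts the next batch
            start_idx, running = end_idx, length
    if start_idx < len(lengths):
        batches.append(torch.arange(start_idx, len(lengths)))  # flush the leftover samples
    return batches

print(batch_indices_by_tokens([3, 4, 2, 5, 1, 6], max_tokens=10))
# [tensor([0, 1, 2]), tensor([3, 4]), tensor([5])]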