class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        n_heads,
        d_model,
        d_ff,
        slf_attn_dropout,
        ffn_dropout,
        residual_dropout,
        normalize_before=False,
        concat_after=False,
        relative_positional=False,
        activation="relu",
    ):
        super(TransformerEncoderLayer, self).__init__()

        self.relative_positional = relative_positional
        if self.relative_positional:
            self.slf_attn = MultiHeadedSelfAttentionWithRelPos(
                n_heads, d_model, slf_attn_dropout)
        else:
            self.slf_attn = MultiHeadedSelfAttention(n_heads, d_model, slf_attn_dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(residual_dropout)
        self.dropout2 = nn.Dropout(residual_dropout)

        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(d_model * 2, d_model)

    def forward(self, x, mask, pos=None):
        # Self-attention sub-layer with pre- or post-LayerNorm residual connection.
        if self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.relative_positional:
            slf_attn_out, slf_attn_weights = self.slf_attn(x, mask, pos)
        else:
            slf_attn_out, slf_attn_weights = self.slf_attn(x, mask)

        if self.concat_after:
            x = residual + self.concat_linear(flow.cat([x, slf_attn_out], dim=-1))
        else:
            x = residual + self.dropout1(slf_attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        # Position-wise feed-forward sub-layer.
        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        x = residual + self.dropout2(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, {"slf_attn_weights": slf_attn_weights}

    def inference(self, x, mask, pos=None, cache=None):
        # Same wiring as forward(), but uses the cached self-attention state and
        # skips residual dropout for step-wise decoding.
        if self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.relative_positional:
            slf_attn_out, slf_attn_weights, new_cache = self.slf_attn.inference(
                x, mask, cache, pos)
        else:
            slf_attn_out, slf_attn_weights, new_cache = self.slf_attn.inference(
                x, mask, cache)

        if self.concat_after:
            x = residual + self.concat_linear(flow.cat([x, slf_attn_out], dim=-1))
        else:
            x = residual + slf_attn_out
        if not self.normalize_before:
            x = self.norm1(x)

        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, new_cache, {"slf_attn_weights": slf_attn_weights}
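# Hedged usage sketch (illustrative, not part of the original module): a quick
# smoke test of TransformerEncoderLayer. It assumes the helper modules referenced
# above (MultiHeadedSelfAttention, PositionwiseFeedForward, ...) are defined in
# this file, that OneFlow tensors are used (the layer calls flow.cat), and that
# the self-attention mask is broadcastable to the attention scores, e.g.
# (batch, 1, time); the mask layout is an assumption.
def _encoder_layer_smoke_test():
    import oneflow as flow

    layer = TransformerEncoderLayer(
        n_heads=4, d_model=256, d_ff=1024,
        slf_attn_dropout=0.1, ffn_dropout=0.1, residual_dropout=0.1,
        normalize_before=True,
    )
    x = flow.randn(2, 50, 256)     # (batch, time, d_model)
    mask = flow.ones(2, 1, 50)     # assumed mask layout: (batch, 1, time)
    out, attn = layer(x, mask)     # out: (2, 50, 256)
    return out, attn["slf_attn_weights"]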
class ConformerEncoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, cov_kernel_size, n_heads,
                 slf_attn_dropout=0.0, ffn_dropout=0.0, residual_dropout=0.1,
                 conv_dropout=0.0, macaron_style=True, conv_first=False,
                 ffn_scale=0.5, conv_bias=True, relative_positional=True,
                 activation='glu'):
        super(ConformerEncoderBlock, self).__init__()

        self.conv_first = conv_first
        self.macaron_style = macaron_style
        self.ffn_scale = ffn_scale
        self.relative_positional = relative_positional
        self.residual_dropout = residual_dropout

        if self.macaron_style:
            # First half of the macaron feed-forward pair.
            self.pre_ffn = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation=activation)
            self.macaron_ffn_norm = nn.LayerNorm(d_model)

        if self.relative_positional:
            self.mha = MultiHeadedSelfAttentionWithRelPos(n_heads, d_model, slf_attn_dropout)
        else:
            self.mha = MultiHeadedSelfAttention(n_heads, d_model, slf_attn_dropout)
        self.mha_norm = nn.LayerNorm(d_model)

        self.conv = ConformerConvolutionModule(d_model, cov_kernel_size, conv_bias, conv_dropout)
        self.conv_norm = nn.LayerNorm(d_model)

        self.post_ffn = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation=activation)
        self.post_ffn_norm = nn.LayerNorm(d_model)

        self.final_norm = nn.LayerNorm(d_model)

    def pre_ffn_forward(self, x, dropout=0.0):
        # Half-scaled macaron feed-forward sub-layer with pre-norm residual.
        residual = x
        x = self.macaron_ffn_norm(x)
        return residual + self.ffn_scale * F.dropout(self.pre_ffn(x), p=dropout, training=self.training)

    def post_ffn_forward(self, x, dropout=0.0):
        # Half-scaled feed-forward sub-layer after attention/convolution.
        residual = x
        x = self.post_ffn_norm(x)
        return residual + self.ffn_scale * F.dropout(self.post_ffn(x), p=dropout, training=self.training)

    def conv_augment_forward(self, x, mask, dropout=0.0):
        # Convolution sub-layer with pre-norm residual.
        residual = x
        x = self.conv_norm(x)
        return residual + F.dropout(self.conv(x, mask), p=dropout, training=self.training)

    def attn_forward(self, x, mask, pos, dropout=0.0):
        # Multi-headed self-attention sub-layer with pre-norm residual.
        residual = x
        x = self.mha_norm(x)
        if self.relative_positional:
            slf_attn_out, slf_attn_weights = self.mha(x, mask.unsqueeze(1), pos)
        else:
            slf_attn_out, slf_attn_weights = self.mha(x, mask.unsqueeze(1))
        slf_attn_out = residual + F.dropout(slf_attn_out, p=dropout, training=self.training)
        return slf_attn_out, slf_attn_weights

    def forward(self, x, mask, pos=None):
        if self.macaron_style:
            x = self.pre_ffn_forward(x, dropout=self.residual_dropout)

        if self.conv_first:
            x = self.conv_augment_forward(x, mask, dropout=self.residual_dropout)
            x, slf_attn_weights = self.attn_forward(x, mask, pos, dropout=self.residual_dropout)
        else:
            x, slf_attn_weights = self.attn_forward(x, mask, pos, dropout=self.residual_dropout)
            x = self.conv_augment_forward(x, mask, dropout=self.residual_dropout)

        x = self.post_ffn_forward(x, dropout=self.residual_dropout)

        return self.final_norm(x), {'slf_attn_weights': slf_attn_weights}

    def attn_infer(self, x, mask, pos, cache):
        residual = x
        x = self.mha_norm(x)
        if self.relative_positional:
            slf_attn_out, slf_attn_weights, new_cache = self.mha.inference(x, mask.unsqueeze(1), pos, cache)
        else:
            slf_attn_out, slf_attn_weights, new_cache = self.mha.inference(x, mask.unsqueeze(1), cache)
        return residual + slf_attn_out, slf_attn_weights, new_cache

    def inference(self, x, mask, pos=None, cache=None):
        if self.macaron_style:
            x = self.pre_ffn_forward(x)

        if self.conv_first:
            x = self.conv_augment_forward(x, mask)
            x, slf_attn_weights, new_cache = self.attn_infer(x, mask, pos, cache)
        else:
            x, slf_attn_weights, new_cache = self.attn_infer(x, mask, pos, cache)
            x = self.conv_augment_forward(x, mask)

        x = self.post_ffn_forward(x)

        return self.final_norm(x), new_cache, {'slf_attn_weights': slf_attn_weights}
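# Hedged usage sketch (illustrative, not part of the original module): the block
# applies, in order, a half-scaled macaron FFN, self-attention, convolution,
# a second half-scaled FFN, and a final LayerNorm. The sketch assumes the helper
# modules above are defined, that `mask` is a (batch, time) padding mask (the
# block itself calls mask.unsqueeze(1) before attention), and uses
# relative_positional=False so no `pos` tensor is needed.
def _conformer_block_smoke_test():
    import oneflow as flow

    block = ConformerEncoderBlock(
        d_model=256, d_ff=1024, cov_kernel_size=15, n_heads=4,
        relative_positional=False,
    )
    x = flow.randn(2, 50, 256)   # (batch, time, d_model)
    mask = flow.ones(2, 50)      # assumed (batch, time) padding mask
    out, attn = block(x, mask)   # out: (2, 50, 256)
    return out, attn["slf_attn_weights"]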
class TransformerDecoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, memory_dim,
                 slf_attn_dropout=0.0, src_attn_dropout=0.0, ffn_dropout=0.0,
                 residual_dropout=0.1, normalize_before=False, concat_after=False,
                 relative_positional=False, activation='relu'):
        super(TransformerDecoderLayer, self).__init__()

        self.relative_positional = relative_positional
        if self.relative_positional:
            self.slf_attn = MultiHeadedSelfAttentionWithRelPos(n_heads, d_model, slf_attn_dropout)
        else:
            self.slf_attn = MultiHeadedSelfAttention(n_heads, d_model, slf_attn_dropout)
        self.src_attn = MultiHeadedCrossAttention(n_heads, d_model, memory_dim, src_attn_dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(residual_dropout)
        self.dropout2 = nn.Dropout(residual_dropout)
        self.dropout3 = nn.Dropout(residual_dropout)

        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(d_model * 2, d_model)
            self.concat_linear2 = nn.Linear(d_model * 2, d_model)

    def forward(self, tgt, tgt_mask, memory, memory_mask, pos):
        """Compute decoded features.

        :param torch.Tensor tgt: decoded previous target features (batch, max_time_out, size)
        :param torch.Tensor tgt_mask: mask for tgt (batch, max_time_out)
        :param torch.Tensor memory: encoded source features (batch, max_time_in, size)
        :param torch.Tensor memory_mask: mask for memory (batch, max_time_in)
        """
        # Masked self-attention sub-layer.
        if self.normalize_before:
            tgt = self.norm1(tgt)
        residual = tgt
        if self.relative_positional:
            slf_attn_out, slf_attn_weights = self.slf_attn(tgt, tgt_mask, pos)
        else:
            slf_attn_out, slf_attn_weights = self.slf_attn(tgt, tgt_mask)

        if self.concat_after:
            x = residual + self.concat_linear1(torch.cat((tgt, slf_attn_out), dim=-1))
        else:
            x = residual + self.dropout1(slf_attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        # Encoder-decoder cross-attention sub-layer.
        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        src_attn_out, src_attn_weights = self.src_attn(x, memory, memory_mask)
        if self.concat_after:
            x = residual + self.concat_linear2(torch.cat((x, src_attn_out), dim=-1))
        else:
            x = residual + self.dropout2(src_attn_out)
        if not self.normalize_before:
            x = self.norm2(x)

        # Position-wise feed-forward sub-layer.
        if self.normalize_before:
            x = self.norm3(x)
        residual = x
        x = residual + self.dropout3(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        return x, {'slf_attn_weights': slf_attn_weights, 'src_attn_weights': src_attn_weights}

    def inference(self, x, xmask, memory, memory_mask=None, pos=None, cache=None):
        # Avoid a mutable default argument; treat a missing cache as empty.
        if cache is None:
            cache = {'slf': None, 'src': None}

        if self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.relative_positional:
            slf_attn_out, slf_attn_weight, slf_cache = self.slf_attn.inference(x, xmask, pos, cache=cache['slf'])
        else:
            slf_attn_out, slf_attn_weight, slf_cache = self.slf_attn.inference(x, xmask, cache=cache['slf'])

        if self.concat_after:
            x = residual + self.concat_linear1(torch.cat((x, slf_attn_out), dim=-1))
        else:
            x = residual + self.dropout1(slf_attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        src_attn_out, src_attn_weight, src_cache = self.src_attn.inference(x, memory, memory_mask, cache['src'])
        if self.concat_after:
            x = residual + self.concat_linear2(torch.cat((x, src_attn_out), dim=-1))
        else:
            x = residual + self.dropout2(src_attn_out)
        if not self.normalize_before:
            x = self.norm2(x)

        if self.normalize_before:
            x = self.norm3(x)
        residual = x
        x = residual + self.dropout3(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        return x, {'slf_attn_weight': slf_attn_weight, 'src_attn_weight': src_attn_weight}, {'slf': slf_cache, 'src': src_cache}
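# Hedged usage sketch (illustrative, not part of the original module): a training-
# style call with encoder memory, followed by a single step of cached inference.
# Mask layouts are assumptions (a causal (batch, time_out, time_out) self-attention
# mask and a (batch, 1, time_in) memory mask), the helper attention/FFN modules
# above must be defined for this to run, and torch tensors match torch.cat in the layer.
def _decoder_layer_smoke_test():
    import torch

    layer = TransformerDecoderLayer(
        n_heads=4, d_model=256, d_ff=1024, memory_dim=256,
        normalize_before=True,
    )
    tgt = torch.randn(2, 10, 256)
    tgt_mask = torch.tril(torch.ones(10, 10)).unsqueeze(0).expand(2, -1, -1)  # assumed causal mask
    memory = torch.randn(2, 50, 256)
    memory_mask = torch.ones(2, 1, 50)                                        # assumed memory mask
    out, attn = layer(tgt, tgt_mask, memory, memory_mask, pos=None)

    # Step-wise decoding: pass one new frame and carry the returned cache forward.
    step = torch.randn(2, 1, 256)
    step_mask = torch.ones(2, 1, 1)
    out_step, weights, cache = layer.inference(step, step_mask, memory, memory_mask)
    return out, out_step, cache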