Example 1: TransformerEncoderLayer
# Assumed imports for this example: flow.cat below follows the OneFlow tensor API.
import oneflow as flow
import oneflow.nn as nn

# MultiHeadedSelfAttention, MultiHeadedSelfAttentionWithRelPos and
# PositionwiseFeedForward are assumed to be provided by the surrounding package.


class TransformerEncoderLayer(nn.Module):
    def __init__(
        self,
        n_heads,
        d_model,
        d_ff,
        slf_attn_dropout,
        ffn_dropout,
        residual_dropout,
        normalize_before=False,
        concat_after=False,
        relative_positional=False,
        activation="relu",
    ):
        super(TransformerEncoderLayer, self).__init__()

        self.relative_positional = relative_positional

        if self.relative_positional:
            self.slf_attn = MultiHeadedSelfAttentionWithRelPos(
                n_heads, d_model, slf_attn_dropout)
        else:
            self.slf_attn = MultiHeadedSelfAttention(n_heads, d_model,
                                                     slf_attn_dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, ffn_dropout,
                                                    activation)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(residual_dropout)
        self.dropout2 = nn.Dropout(residual_dropout)

        self.normalize_before = normalize_before
        self.concat_after = concat_after

        if self.concat_after:
            # Project the concatenation of the sub-layer input and output back
            # to d_model instead of using a plain residual addition.
            self.concat_linear = nn.Linear(d_model * 2, d_model)

    def forward(self, x, mask, pos=None):
        # Self-attention sub-layer (pre-norm if normalize_before, else post-norm).
        if self.normalize_before:
            x = self.norm1(x)
        residual = x

        if self.relative_positional:
            slf_attn_out, slf_attn_weights = self.slf_attn(x, mask, pos)
        else:
            slf_attn_out, slf_attn_weights = self.slf_attn(x, mask)

        if self.concat_after:
            x = residual + self.concat_linear(
                flow.cat([x, slf_attn_out], dim=-1))
        else:
            x = residual + self.dropout1(slf_attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        x = residual + self.dropout2(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, {"slf_attn_weights": slf_attn_weights}

    def inference(self, x, mask, pos=None, cache=None):
        # Same as forward(), but without dropout and with an incremental
        # attention cache for step-by-step decoding.
        if self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.relative_positional:
            slf_attn_out, slf_attn_weights, new_cache = self.slf_attn.inference(
                x, mask, cache, pos)
        else:
            slf_attn_out, slf_attn_weights, new_cache = self.slf_attn.inference(
                x, mask, cache)

        if self.concat_after:
            x = residual + self.concat_linear(
                flow.cat([x, slf_attn_out], dim=-1))
        else:
            x = residual + slf_attn_out
        if not self.normalize_before:
            x = self.norm1(x)

        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        x = residual + self.feed_forward(x)
        if not self.normalize_before:
            x = self.norm2(x)

        return x, new_cache, {"slf_attn_weights": slf_attn_weights}
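
The normalize_before flag above toggles between the post-norm layout of the original Transformer and a pre-norm variant; both forward() and inference() repeat the same residual/norm pattern for every sub-layer. Below is a condensed, self-contained sketch of that pattern, with a plain linear layer standing in for the attention or feed-forward sub-layer (the helper name and shapes are illustrative only, not part of the code above):

import torch
import torch.nn as nn

def sublayer_with_norm(x, sublayer, norm, normalize_before):
    # Mirrors the pattern in forward() above: with normalize_before the input
    # is normalized first (and the residual is taken from the normalized
    # input); otherwise the norm is applied after the residual addition.
    if normalize_before:
        x = norm(x)
    residual = x
    x = residual + sublayer(x)
    if not normalize_before:
        x = norm(x)
    return x

norm = nn.LayerNorm(16)
ffn = nn.Linear(16, 16)          # stand-in sub-layer
x = torch.randn(2, 10, 16)
print(sublayer_with_norm(x, ffn, norm, True).shape)   # torch.Size([2, 10, 16])
print(sublayer_with_norm(x, ffn, norm, False).shape)  # torch.Size([2, 10, 16])
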
Example 2: ConformerEncoderBlock
# Assumed imports for this example (PyTorch-style API).
import torch.nn as nn
import torch.nn.functional as F

# MultiHeadedSelfAttention(WithRelPos), ConformerConvolutionModule and
# PositionwiseFeedForward are assumed to be provided by the surrounding package.


class ConformerEncoderBlock(nn.Module):
    def __init__(self, d_model, d_ff, cov_kernel_size, n_heads, slf_attn_dropout=0.0, ffn_dropout=0.0,
                 residual_dropout=0.1, conv_dropout=0.0, macaron_style=True, conv_first=False,
                 ffn_scale=0.5, conv_bias=True, relative_positional=True, activation='glu'):
        super(ConformerEncoderBlock, self).__init__()

        self.conv_first = conv_first
        self.macaron_style = macaron_style
        self.ffn_scale = ffn_scale
        self.relative_positional = relative_positional
        self.residual_dropout = residual_dropout

        if self.macaron_style:
            self.pre_ffn = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation=activation)
            self.macaron_ffn_norm = nn.LayerNorm(d_model)

        if self.relative_positional:
            self.mha = MultiHeadedSelfAttentionWithRelPos(n_heads, d_model, slf_attn_dropout)
        else:
            self.mha = MultiHeadedSelfAttention(n_heads, d_model, slf_attn_dropout)
        self.mha_norm = nn.LayerNorm(d_model)

        self.conv = ConformerConvolutionModule(d_model, cov_kernel_size, conv_bias, conv_dropout)
        self.conv_norm = nn.LayerNorm(d_model)

        self.post_ffn = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation=activation)
        self.post_ffn_norm = nn.LayerNorm(d_model)

        self.final_norm = nn.LayerNorm(d_model)

    def pre_ffn_forward(self, x, dropout=0.0):
        # Macaron half-step feed-forward: x + ffn_scale * FFN(LayerNorm(x)).
        residual = x
        x = self.macaron_ffn_norm(x)
        return residual + self.ffn_scale * F.dropout(self.pre_ffn(x), p=dropout,
                                                     training=self.training)

    def post_ffn_forward(self, x, dropout=0.0):
        residual = x
        x = self.post_ffn_norm(x)
        return residual + self.ffn_scale * F.dropout(self.post_ffn(x), p=dropout,
                                                     training=self.training)

    def conv_augment_forward(self, x, mask, dropout=0.0):
        residual = x
        x = self.conv_norm(x)
        return residual + F.dropout(self.conv(x, mask), p=dropout,
                                    training=self.training)

    def attn_forward(self, x, mask, pos, dropout=0.0):
        residual = x
        x = self.mha_norm(x)
        if self.relative_positional:
            slf_attn_out, slf_attn_weights = self.mha(x, mask.unsqueeze(1), pos)
        else:
            slf_attn_out, slf_attn_weights = self.mha(x, mask.unsqueeze(1))
        slf_attn_out = residual + F.dropout(slf_attn_out, p=dropout,
                                            training=self.training)
        return slf_attn_out, slf_attn_weights

    def forward(self, x, mask, pos=None):
        # Block layout: (macaron FFN) -> attention / convolution (order set by
        # conv_first) -> feed-forward -> final LayerNorm, all pre-norm residuals.
        if self.macaron_style:
            x = self.pre_ffn_forward(x, dropout=self.residual_dropout)

        if self.conv_first:
            x = self.conv_augment_forward(x, mask, dropout=self.residual_dropout)
            x, slf_attn_weights = self.attn_forward(x, mask, pos, dropout=self.residual_dropout)
        else:
            x, slf_attn_weights = self.attn_forward(x, mask, pos, dropout=self.residual_dropout)
            x = self.conv_augment_forward(x, mask, dropout=self.residual_dropout)

        x = self.post_ffn_forward(x, dropout=self.residual_dropout)

        return self.final_norm(x), {'slf_attn_weights': slf_attn_weights}

    def attn_infer(self, x, mask, pos, cache):
        residual = x
        x = self.mha_norm(x)
        if self.relative_positional:
            slf_attn_out, slf_attn_weights, new_cache = self.mha.inference(x, mask.unsqueeze(1), pos, cache)
        else:
            slf_attn_out, slf_attn_weights, new_cache = self.mha.inference(x, mask.unsqueeze(1), cache)
        return residual + slf_attn_out, slf_attn_weights, new_cache

    def inference(self, x, mask, pos=None, cache=None):
        # Same as forward(), but dropout-free and with an incremental attention cache.
        if self.macaron_style:
            x = self.pre_ffn_forward(x)

        if self.conv_first:
            x = self.conv_augment_forward(x, mask)
            x, slf_attn_weights, new_cache = self.attn_infer(x, mask, pos, cache)
        else:
            x, slf_attn_weights, new_cache = self.attn_infer(x, mask, pos, cache)
            x = self.conv_augment_forward(x, mask)

        x = self.post_ffn_forward(x)

        return self.final_norm(x),  new_cache, {'slf_attn_weights': slf_attn_weights}
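
The macaron_style option wraps the attention/convolution pair between two feed-forward sub-layers, each contributing only a half-step residual via ffn_scale (0.5 by default). A minimal self-contained sketch of that half-step residual, with a plain two-layer MLP standing in for the real PositionwiseFeedForward (the HalfStepFFN name is illustrative only):

import torch
import torch.nn as nn

class HalfStepFFN(nn.Module):
    # Toy macaron-style half-step: x + scale * FFN(LayerNorm(x)).
    def __init__(self, d_model, d_ff, scale=0.5):
        super().__init__()
        self.norm = nn.LayerNorm(d_model)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(),
                                 nn.Linear(d_ff, d_model))
        self.scale = scale

    def forward(self, x):
        return x + self.scale * self.ffn(self.norm(x))

x = torch.randn(2, 10, 16)
print(HalfStepFFN(16, 64)(x).shape)  # torch.Size([2, 10, 16])
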
Example 3: TransformerDecoderLayer
# Assumed imports for this example.
import torch
import torch.nn as nn

# MultiHeadedSelfAttention(WithRelPos), MultiHeadedCrossAttention and
# PositionwiseFeedForward are assumed to be provided by the surrounding package.


class TransformerDecoderLayer(nn.Module):
    def __init__(self, n_heads, d_model, d_ff, memory_dim, slf_attn_dropout=0.0, src_attn_dropout=0.0, ffn_dropout=0.0, residual_dropout=0.1,
                 normalize_before=False, concat_after=False, relative_positional=False, activation='relu'):
        super(TransformerDecoderLayer, self).__init__()

        self.relative_positional = relative_positional

        if self.relative_positional:
            self.slf_attn = MultiHeadedSelfAttentionWithRelPos(n_heads, d_model, slf_attn_dropout)
        else:
            self.slf_attn = MultiHeadedSelfAttention(n_heads, d_model, slf_attn_dropout)
        self.src_attn = MultiHeadedCrossAttention(n_heads, d_model, memory_dim, src_attn_dropout)
        self.feed_forward = PositionwiseFeedForward(d_model, d_ff, ffn_dropout, activation)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

        self.dropout1 = nn.Dropout(residual_dropout)
        self.dropout2 = nn.Dropout(residual_dropout)
        self.dropout3 = nn.Dropout(residual_dropout)

        self.normalize_before = normalize_before
        self.concat_after = concat_after

        if self.concat_after:
            self.concat_linear1 = nn.Linear(d_model * 2, d_model)
            self.concat_linear2 = nn.Linear(d_model * 2, d_model)

    def forward(self, tgt, tgt_mask, memory, memory_mask, pos):
        """Compute decoded features

        :param torch.Tensor tgt: decoded previous target features (batch, max_time_out, size)
        :param torch.Tensor tgt_mask: mask for x (batch, max_time_out)
        :param torch.Tensor memory: encoded source features (batch, max_time_in, size)
        :param torch.Tensor memory_mask: mask for memory (batch, max_time_in)
        """

        if self.normalize_before:
            tgt = self.norm1(tgt)
        residual = tgt

        if self.relative_positional:
            slf_attn_out, slf_attn_weights = self.slf_attn(tgt, tgt_mask, pos)
        else:
            slf_attn_out, slf_attn_weights = self.slf_attn(tgt, tgt_mask)

        if self.concat_after:
            x = residual + self.concat_linear1(torch.cat((tgt, slf_attn_out), dim=-1))
        else:
            x = residual + self.dropout1(slf_attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        src_attn_out, src_attn_weights = self.src_attn(x, memory, memory_mask)
        if self.concat_after:
            x = residual + self.concat_linear2(torch.cat((x, src_attn_out), dim=-1))
        else:
            x = residual + self.dropout2(src_attn_out)
        if not self.normalize_before:
            x = self.norm2(x)

        if self.normalize_before:
            x = self.norm3(x)
        residual = x
        x = residual + self.dropout3(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        return x, {'slf_attn_weights': slf_attn_weights, 'src_attn_weights': src_attn_weights}

    def inference(self, x, xmask, memory, memory_mask=None, pos=None, cache=None):
        # Avoid a mutable default argument: start from an empty per-call cache.
        if cache is None:
            cache = {'slf': None, 'src': None}

        if self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.relative_positional:
            slf_attn_out, slf_attn_weight, slf_cache = self.slf_attn.inference(x, xmask, pos, cache=cache['slf'])
        else:
            slf_attn_out, slf_attn_weight, slf_cache = self.slf_attn.inference(x, xmask, cache=cache['slf'])
        if self.concat_after:
            x = residual + self.concat_linear1(torch.cat((x, slf_attn_out), dim=-1))
        else:
            x = residual + self.dropout1(slf_attn_out)
        if not self.normalize_before:
            x = self.norm1(x)

        if self.normalize_before:
            x = self.norm2(x)
        residual = x
        src_attn_out, src_attn_weight, src_cache = self.src_attn.inference(x, memory, memory_mask, cache['src'])
        if self.concat_after:
            x = residual + self.concat_linear2(torch.cat((x, src_attn_out), dim=-1))
        else:
            x = residual + self.dropout2(src_attn_out)
        if not self.normalize_before:
            x = self.norm2(x)

        if self.normalize_before:
            x = self.norm3(x)
        residual = x
        x = residual + self.dropout3(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)

        return x, {'slf_attn_weight': slf_attn_weight, 'src_attn_weight': src_attn_weight}, {'slf': slf_cache, 'src': src_cache}
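
During incremental decoding, inference() returns an updated {'slf': ..., 'src': ...} cache that the caller passes back in at the next step, so each layer only recomputes attention over the states it has already produced. A self-contained toy sketch of that threading, with a stand-in layer that simply concatenates the new frame onto its cached states (toy_layer_inference is illustrative only and not the real attention cache format):

import torch

def toy_layer_inference(x, cache):
    # Stand-in for TransformerDecoderLayer.inference: append the new frame to
    # the cached self-attention states and hand the updated cache back.
    slf = x if cache['slf'] is None else torch.cat([cache['slf'], x], dim=1)
    return x, {'slf': slf, 'src': cache['src']}

cache = {'slf': None, 'src': None}
for _ in range(3):
    frame = torch.randn(1, 1, 8)              # one decoded frame per step
    out, cache = toy_layer_inference(frame, cache)
print(cache['slf'].shape)                     # torch.Size([1, 3, 8])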