Example #1
    def __init__(self, h, d_model, p, d_ff, attn_p=0.1, version=1.0):
        super(ParallelEncoderLayer, self).__init__()
        self.version = version

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  static=onmt.constants.static)
        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 static=onmt.constants.static)
        self.multihead = MultiHeadAttention(h,
                                            d_model,
                                            attn_p=attn_p,
                                            static=onmt.constants.static)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model, d_ff, ff_p)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)
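
A note on how these pieces fit together: the constructor above only wires up the sub-modules, with PrePostProcessing(sequence='n') acting as the pre-norm step and PrePostProcessing(sequence='da') as the dropout-plus-residual step after each sub-layer. Below is a minimal, self-contained sketch of that pre-norm pattern written with plain PyTorch modules; the class name, forward signature and exact residual placement are illustrative assumptions, not the repository's actual forward code.

import torch.nn as nn

class PreNormEncoderLayerSketch(nn.Module):
    """Rough stand-in for the layer built above (pre-norm self-attention + FFN)."""

    def __init__(self, h, d_model, p, d_ff, attn_p=0.1):
        super().__init__()
        self.norm_attn = nn.LayerNorm(d_model)   # plays the role of sequence='n'
        self.norm_ffn = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(p)             # the 'd' in sequence='da'
        self.attn = nn.MultiheadAttention(d_model, h, dropout=attn_p, batch_first=True)
        self.ffn = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))

    def forward(self, x, key_padding_mask=None):
        # self-attention sub-layer: norm -> attention -> dropout -> residual add (the 'a')
        y = self.norm_attn(x)
        y, _ = self.attn(y, y, y, key_padding_mask=key_padding_mask)
        x = x + self.dropout(y)
        # feed-forward sub-layer with the same pre-norm residual pattern
        return x + self.dropout(self.ffn(self.norm_ffn(x)))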
Example #2
    def __init__(self, opt, embedding, language_embeddings=None, **kwargs):
        super(SpeechLSTMDecoder, self).__init__()

        # Keep for reference

        # Define layers
        self.model_size = opt.model_size
        self.layers = opt.layers
        self.dropout = opt.dropout

        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.variational_dropout = opt.variational_dropout

        self.encoder_type = opt.encoder_type

        self.lstm = nn.LSTM(self.model_size,
                            self.model_size,
                            self.layers,
                            dropout=self.dropout,
                            batch_first=True)

        self.fast_xattention = opt.fast_xattention
        self.n_head = 1  # fixed
        # also fix attention dropout to 0.0

        if opt.fast_xattention:
            self.multihead_tgt = EncdecMultiheadAttn(self.n_head,
                                                     opt.model_size, 0.0)
        else:
            self.multihead_tgt = MultiHeadAttention(self.n_head,
                                                    opt.model_size,
                                                    attn_p=0.0,
                                                    share=3)

        self.preprocess_layer = PrePostProcessing(
            self.model_size,
            self.emb_dropout,
            sequence='d',
            variational=self.variational_dropout)

        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')
        self.preprocess_attn = PrePostProcessing(self.model_size,
                                                 0,
                                                 sequence='n')

        self.word_lut = embedding

        self.encoder_cnn_downsampling = opt.cnn_downsampling
        self.language_embeddings = language_embeddings
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type

        if self.language_embedding_type == 'concat':
            self.projector = nn.Linear(opt.model_size * 2, opt.model_size)

        print("* Create LSTM Decoder with %d layers." % self.layers)
Example #3
    def __init__(self, opt, death_rate=0.0, **kwargs):
        super().__init__()
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.fast_self_attention = opt.fast_self_attention

        self.preprocess_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da',
                                                  variational=self.variational)
        self.preprocess_ffn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da',
                                                 variational=self.variational)
        d_head = opt.model_size // opt.n_heads
        self.adaptive_type = opt.adaptive
        self.factor_size = opt.layers

        # this layer uses fast relative self-attention by default
        if self.adaptive_type == 'universal':
            self.multihead = RelativeSelfMultiheadAttn(opt.model_size, opt.n_heads, opt.attn_dropout)

            self.feedforward = PositionWiseFeedForward(opt.model_size, opt.inner_size, opt.dropout,
                                                       variational=self.variational)
        else:
            self.multihead = AdaptiveRelativeAttn(opt.model_size, opt.n_heads, self.factor_size, opt.attn_dropout)
            self.feedforward = AdaptiveFeedForward(opt.model_size, opt.inner_size, self.factor_size,
                                                   opt.dropout, variational=self.variational)
Example #4
    def add_layers(self, n_new_layer):

        self.new_modules = list()
        self.layers += n_new_layer

        for i in range(n_new_layer):
            layer = ParallelEncoderLayer(self.n_heads, self.model_size,
                                         self.dropout, self.inner_size,
                                         self.attn_dropout)

            # the first new layer initializes its preprocessing from the current last postprocessing
            if i == 0:
                layer.preprocess_attn.load_state_dict(
                    self.postprocess_layer.state_dict())
                #~ layer.preprocess_attn.layer_norm.function.weight.requires_grad = False
                #~ layer.preprocess_attn.layer_norm.function.bias.requires_grad = False
                #~ if hasattr(layer.postprocess_attn, 'k'):
                #~ layer.postprocess_attn.k.data.fill_(0.01)

                # replace the last postprocessing layer with a new one
                self.postprocess_layer = PrePostProcessing(self.model_size,
                                                           0,
                                                           sequence='n')

            self.layer_modules.append(layer)
Example #5
    def __init__(
        self,
        h,
        d_model,
        p,
        d_ff,
        attn_p=0.1,
    ):
        super(LMDecoderLayer, self).__init__()

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  static=onmt.constants.static)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 static=onmt.constants.static)

        self.multihead_tgt = MultiHeadAttention(h,
                                                d_model,
                                                attn_p=attn_p,
                                                static=onmt.constants.static,
                                                share=1)

        ff_p = p
        feedforward = FeedForward(d_model,
                                  d_ff,
                                  ff_p,
                                  static=onmt.constants.static)
        self.feedforward = Bottle(feedforward)
Example #6
    def __init__(self, h, d_model, p, d_ff, attn_p=0.1, version=1.0, ignore_source=False,
                 variational=False, death_rate=0.0):
        super(TransformerXLDecoderLayer, self).__init__()
        self.version = version
        self.ignore_source = ignore_source
        self.variational = variational
        self.death_rate = death_rate

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model, p, sequence='da', variational=self.variational)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model, p, sequence='da', variational=self.variational)

        d_head = d_model // h
        self.multihead_tgt = RelPartialLearnableMultiHeadAttn(h, d_model, d_head, dropatt=attn_p)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model, d_ff, ff_p, variational=self.variational)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        elif onmt.constants.activation_layer == 'linear_swish_linear':
            ff_p = p
            feedforward = FeedForwardSwish(d_model, d_ff, ff_p)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)
Example #7
    def __init__(self,
                 opt,
                 embedding,
                 language_embeddings=None,
                 ignore_source=False,
                 allocate_positions=True):
        super(SpeechLSTMDecoder, self).__init__()

        # Keep for reference

        # Define layers
        self.model_size = opt.model_size
        self.layers = opt.layers
        self.dropout = opt.dropout

        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.variational_dropout = opt.variational_dropout

        self.encoder_type = opt.encoder_type

        self.lstm = nn.LSTM(self.model_size,
                            self.model_size,
                            self.layers,
                            dropout=self.dropout,
                            batch_first=True)

        self.fast_self_attention = opt.fast_self_attention

        if opt.fast_self_attention:
            self.multihead_tgt = SelfMultiheadAttn(opt.model_size, opt.n_heads,
                                                   opt.attn_dropout)
        else:
            self.multihead_tgt = MultiHeadAttention(opt.n_heads,
                                                    opt.model_size,
                                                    attn_p=opt.attn_dropout,
                                                    share=3)

        # self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d',
        #                                           variational=self.variational_dropout)
        self.preprocess_layer = PrePostProcessing(self.model_size,
                                                  0,
                                                  sequence='n')
        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.word_lut = embedding

        self.encoder_cnn_downsampling = opt.cnn_downsampling
        self.language_embeddings = language_embeddings
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type

        if self.language_embedding_type == 'concat':
            self.projector = nn.Linear(opt.model_size * 2, opt.model_size)
Example #8
    def __init__(self, opt, embedding, language_embeddings=None, **kwargs):
        super(SpeechLSTMDecoder, self).__init__()

        # Keep for reference

        # Define layers
        self.model_size = opt.model_size
        self.layers = opt.layers
        self.dropout = opt.dropout

        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.variational_dropout = opt.variational_dropout
        self.multilingual_factorized_weights = opt.multilingual_factorized_weights
        self.mfw_rank = opt.mfw_rank
        self.encoder_type = opt.encoder_type
        self.n_languages = opt.n_languages

        self.lstm = nn.LSTM(self.model_size, self.model_size, self.layers, dropout=self.dropout, batch_first=True)
        if self.multilingual_factorized_weights:
            from onmt.modules.weight_control_lstm import WeightFactoredLSTM
            self.lstm = WeightFactoredLSTM(self.lstm, dropout=opt.weight_drop, n_languages=opt.n_languages,
                                           rank=self.mfw_rank)

        self.fast_xattention = opt.fast_xattention
        self.n_head = 1  # fixed to always use 1 head
        # also fix attention dropout to 0.0

        if self.multilingual_factorized_weights:
            self.fast_xattention = True
            from onmt.modules.multilingual_factorized.encdec_attention import MFWEncdecMultiheadAttn
            self.multihead_tgt = MFWEncdecMultiheadAttn(self.n_head, opt.model_size, 0.0, n_languages=opt.n_languages,
                                                        rank=opt.mfw_rank, weight_drop=0.0)
        else:
            if opt.fast_xattention:
                self.multihead_tgt = EncdecMultiheadAttn(self.n_head, opt.model_size, 0.0)
            else:
                self.multihead_tgt = MultiHeadAttention(self.n_head, opt.model_size, attn_p=0.0, share=3)

        self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d',
                                                  variational=self.variational_dropout)

        self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')
        self.preprocess_attn = PrePostProcessing(self.model_size, 0, sequence='n')

        self.word_lut = embedding

        self.encoder_cnn_downsampling = opt.cnn_downsampling
        self.language_embeddings = language_embeddings
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type

        if self.language_embedding_type == 'concat':
            self.projector = nn.Linear(opt.model_size * 2, opt.model_size)

        print("* Create LSTM Decoder with %d layers." % self.layers)
Example #9
    def __init__(self, opt, dicts, positional_encoder):

        super(ParallelTransformerDecoder, self).__init__()

        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time

        if hasattr(opt, 'grow_dropout'):
            self.grow_dropout = opt.grow_dropout

        if opt.time == 'positional_encoding':
            self.time_transformer = positional_encoder
        elif opt.time == 'gru':
            self.time_transformer = nn.GRU(self.model_size,
                                           self.model_size,
                                           1,
                                           batch_first=True)
        elif opt.time == 'lstm':
            self.time_transformer = nn.LSTM(self.model_size,
                                            self.model_size,
                                            1,
                                            batch_first=True)

        #~ self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=False)
        self.preprocess_layer = PrePostProcessing(self.model_size,
                                                  self.emb_dropout,
                                                  sequence='d',
                                                  static=onmt.constants.static)
        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.constants.PAD)

        self.positional_encoder = positional_encoder

        self.layer_modules = nn.ModuleList([
            DecoderLayer(self.n_heads, self.model_size, self.dropout,
                         self.inner_size, self.attn_dropout)
            for _ in range(self.layers)
        ])

        len_max = self.positional_encoder.len_max
        mask = torch.ByteTensor(
            np.triu(np.ones((len_max, len_max)), k=1).astype('uint8'))
        self.register_buffer('mask', mask)
Example #10
    def __init__(self, opt):
        super(TacotronDecoder, self).__init__()
        self.n_mel_channels = opt.n_mel_channels
        self.n_frames_per_step = opt.n_frames_per_step
        self.encoder_embedding_dim = opt.model_size
        self.attention_rnn_dim = opt.model_size
        self.decoder_rnn_dim = opt.model_size
        self.prenet_dim = opt.prenet_dim
        self.max_decoder_steps = opt.max_decoder_steps
        self.gate_threshold = 0.5
        self.p_attention_dropout = opt.attn_dropout
        self.p_decoder_dropout = opt.dropout
        self.encoder_type = opt.encoder_type

        self.lstm = nn.LSTM(opt.prenet_dim, opt.model_size, 2, dropout=opt.dropout, batch_first=True)

        self.linear_trans = nn.Linear(opt.n_mel_channels * opt.n_frames_per_step, opt.model_size)
        torch.nn.init.xavier_uniform_(self.linear_trans.weight)

        if opt.fast_xattention:
            self.multihead_tgt = EncdecMultiheadAttn(1, opt.model_size, opt.attn_dropout)
        else:
            self.multihead_tgt = MultiHeadAttention(1, opt.model_size, attn_p=opt.attn_dropout, share=3)

        self.preprocess_layer = PrePostProcessing(opt.model_size, 0, sequence='n')

        self.prenet = Prenet(
            opt.n_mel_channels * opt.n_frames_per_step,
            [opt.prenet_dim, opt.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            opt.prenet_dim + opt.model_size,
            opt.model_size)

        self.attention_layer = Attention(
            opt.model_size, opt.model_size,
            opt.attention_dim, opt.attention_location_n_filters,
            opt.attention_location_kernel_size)

        self.postprocess_layer = PrePostProcessing(opt.model_size, 0, sequence='n')

        self.decoder_rnn = nn.LSTMCell(
            opt.model_size + opt.model_size,
            opt.model_size, 1)

        self.linear_projection = LinearNorm(
            opt.model_size,
            opt.n_mel_channels * opt.n_frames_per_step)

        self.gate_layer = LinearNorm(
            opt.model_size, 1,
            bias=True, w_init_gain='sigmoid')
Example #11
    def __init__(self, opt, dicts, positional_encoder):

        super(ParallelTransformerEncoder, self).__init__()

        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time

        if hasattr(opt, 'grow_dropout'):
            self.grow_dropout = opt.grow_dropout

        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.constants.PAD)

        if opt.time == 'positional_encoding':
            self.time_transformer = positional_encoder
        elif opt.time == 'gru':
            self.time_transformer = nn.GRU(self.model_size,
                                           self.model_size,
                                           1,
                                           batch_first=True)
        elif opt.time == 'lstm':
            self.time_transformer = nn.LSTM(self.model_size,
                                            self.model_size,
                                            1,
                                            batch_first=True)

        #~ self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d', static=False)
        self.preprocess_layer = PrePostProcessing(self.model_size,
                                                  self.emb_dropout,
                                                  sequence='d',
                                                  static=onmt.constants.static)

        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.positional_encoder = positional_encoder

        self.layer_modules = nn.ModuleList([
            ParallelEncoderLayer(self.n_heads, self.model_size, self.dropout,
                                 self.inner_size, self.attn_dropout)
            for _ in range(self.layers)
        ])
Example #12
    def __init__(self,
                 h,
                 d_model,
                 p,
                 d_ff,
                 attn_p=0.1,
                 variational=False,
                 death_rate=0.0,
                 max_len=64,
                 **kwargs):
        super(DistanceTransformerEncoderLayer, self).__init__()
        self.variational = variational
        self.death_rate = death_rate

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  variational=self.variational)
        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 variational=self.variational)
        # self.multihead = MultiHeadAttention(h, d_model, attn_p=attn_p, share=2)
        d_head = d_model // h
        self.multihead = LearnableRelMultiHeadAttn(h,
                                                   d_model,
                                                   d_head,
                                                   dropatt=attn_p,
                                                   max_len=max_len)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model,
                                      d_ff,
                                      ff_p,
                                      variational=self.variational)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        elif onmt.constants.activation_layer == 'linear_swish_linear':
            ff_p = p
            feedforward = FeedForwardSwish(d_model,
                                           d_ff,
                                           ff_p,
                                           variational=self.variational)
        else:
            raise NotImplementedError

        self.feedforward = Bottle(feedforward)
Example #13
def preprocessing(rezero, model_size, post_norm=False):
    sequence = ''

    if not rezero and not post_norm:
        sequence += 'n'

    return PrePostProcessing(model_size, 0.0, sequence=sequence)
Example #14
    def __init__(self, input_size, hidden_size, output_size, dropout=0.0):
        """
        :param input_size:
        :param hidden_size:
        :param bottleneck_size:
        :param output_size:
        :param n_hidden: number of hidden states between first hidden and the bottleneck
        """
        super().__init__()
        self.input_size = input_size
        self.dropout = dropout
        self.linear_in = nn.Linear(input_size, hidden_size)
        # self.hiddens = nn.ModuleList()
        # for i in range(n_hidden):
        #     self.hiddens.append(nn.Linear(hidden_size, hidden_size))

        self.bottleneck_in = nn.Linear(hidden_size, hidden_size)
        # self.bottleneck_out = nn.Linear(bottleneck_size, hidden_size)
        self.last_linear = nn.Linear(hidden_size, output_size)

        stdv = 1. / math.sqrt(self.last_linear.weight.size(1))

        torch.nn.init.uniform_(self.last_linear.weight, -stdv, stdv)
        self.last_linear.bias.data.zero_()

        self.postprocess_layer = PrePostProcessing(hidden_size,
                                                   0,
                                                   sequence='n')
Example #15
    def __init__(self, opt, death_rate=0.0, **kwargs):
        super(RelativeTransformerEncoderLayer, self).__init__()
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.fast_self_attention = opt.fast_self_attention

        self.preprocess_attn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)
        self.preprocess_ffn = PrePostProcessing(opt.model_size,
                                                opt.dropout,
                                                sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='da',
                                                 variational=self.variational)
        d_head = opt.model_size // opt.n_heads
        if not self.fast_self_attention:
            self.multihead = RelPartialLearnableMultiHeadAttn(
                opt.n_heads, opt.model_size, d_head, dropatt=opt.attn_dropout)
        else:
            self.multihead = RelativeSelfMultiheadAttn(opt.model_size,
                                                       opt.n_heads,
                                                       opt.attn_dropout)

        print(opt.fast_feed_forward)
        if not opt.fast_feed_forward:
            feedforward = FeedForward(opt.model_size,
                                      opt.inner_size,
                                      opt.dropout,
                                      variational=self.variational)
            self.feedforward = Bottle(feedforward)
        else:
            self.feedforward = PositionWiseFeedForward(
                opt.model_size,
                opt.inner_size,
                opt.dropout,
                variational=self.variational)
Example #16
    def __init__(self, opt):
        super().__init__()
        self.layer_norm = PrePostProcessing(opt.model_size,
                                            opt.dropout,
                                            sequence='n')
        self.attn = MultiHeadAttention(opt.n_heads,
                                       opt.model_size,
                                       attn_p=opt.attn_dropout,
                                       share=2)
        self.dropout = opt.attn_dropout
        self.variational = opt.variational_dropout
Example #17
    def __init__(self, h, d_model, p, d_ff, pos_encoder, time_encoder, attn_p=0.1, version=1.0):
        super(UniversalEncoderLayer, self).__init__()
        self.version = version
        # position and time embeddings are added to the input before the layer
        self.pos_encoder = pos_encoder
        self.time_encoder = time_encoder

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)
        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model, p, sequence='da', static=onmt.constants.static)
        self.multihead = MultiHeadAttention(h, d_model, attn_p=attn_p, static=onmt.constants.static)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model, d_ff, ff_p)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)
Example #18
    def __init__(self, opt, dicts, positional_encoder, time_encoder):

        super(UniversalTransformerDecoder, self).__init__()

        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time

        self.positional_encoder = positional_encoder

        self.time_encoder = time_encoder

        self.preprocess_layer = PrePostProcessing(self.model_size,
                                                  self.emb_dropout,
                                                  sequence='d',
                                                  static=onmt.constants.static)
        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.constants.PAD)

        self.positional_encoder = positional_encoder

        self.recurrent_layer = UniversalDecoderLayer(
            self.n_heads, self.model_size, self.dropout, self.inner_size,
            self.positional_encoder, self.time_encoder, self.attn_dropout)

        len_max = self.positional_encoder.len_max
        mask = torch.ByteTensor(
            np.triu(np.ones((len_max, len_max)), k=1).astype('uint8'))
        self.register_buffer('mask', mask)
Example #19
    def __init__(self, opt, death_rate=0.0):

        super(ConformerEncoderLayer, self).__init__()

        # FFN -> SelfAttention -> Conv -> FFN
        # PreNorm
        self.opt = opt
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.dropout = opt.dropout
        self.ffn_scale = 0.5

        self.preprocess_attn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)

        self.attn = RelativeSelfMultiheadAttn(opt.model_size, opt.n_heads,
                                              opt.attn_dropout)

        self.preprocess_mcr_ffn = PrePostProcessing(opt.model_size,
                                                    opt.dropout,
                                                    sequence='n')

        self.mcr_feedforward = PositionWiseFeedForward(
            opt.model_size,
            opt.inner_size,
            opt.dropout,
            variational=self.variational,
            activation='swish')

        self.preprocess_ffn = PrePostProcessing(opt.model_size,
                                                opt.dropout,
                                                sequence='n')

        self.feedforward = PositionWiseFeedForward(
            opt.model_size,
            opt.inner_size,
            opt.dropout,
            variational=self.variational,
            activation='swish')

        # there is batch norm inside convolution already
        # so no need for layer norm?
        self.preprocess_conv = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_conv = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)
        self.conv = ConformerConvBlock(opt.model_size,
                                       opt.conv_kernel,
                                       activation='swish')
Example #20
    def __init__(self, opt, dicts, positional_encoder, time_encoder):

        super(UniversalTransformerEncoder, self).__init__()

        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time

        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.constants.PAD)

        self.positional_encoder = positional_encoder

        self.time_encoder = time_encoder

        self.preprocess_layer = PrePostProcessing(self.model_size,
                                                  self.emb_dropout,
                                                  sequence='d',
                                                  static=onmt.constants.static)

        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.positional_encoder = positional_encoder

        self.recurrent_layer = UniversalEncoderLayer(
            self.n_heads, self.model_size, self.dropout, self.inner_size,
            self.positional_encoder, self.time_encoder, self.attn_dropout)
Example #21
    def __init__(self, opt, dicts):

        super().__init__()

        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.time = opt.time
        self.encoder_type = opt.encoder_type

        self.preprocess_layer = PrePostProcessing(self.model_size,
                                                  self.emb_dropout,
                                                  sequence='d',
                                                  static=False)

        self.word_lut = nn.Embedding(dicts.size(),
                                     self.model_size,
                                     padding_idx=onmt.constants.PAD)

        self.rnn = nn.LSTM(self.model_size,
                           self.model_size,
                           num_layers=3,
                           dropout=self.dropout)

        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   self.emb_dropout,
                                                   sequence='d',
                                                   static=False)

        self.h = None
        self.c = None
Example #22
    def __init__(self, h, d_model, p, d_ff, attn_p=0.1):
        super(FCTEncoderLayer, self).__init__()

        self.preprocess_attn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_attn = PrePostProcessing(d_model,
                                                  p,
                                                  sequence='da',
                                                  static=True)
        #~ self.multihead = HierarchicalMultiHeadAttention(h, d_model, attn_p=attn_p)
        self.multihead = UniformMultiHeadAttention(h, d_model, attn_p=attn_p)

        self.preprocess_ffn = PrePostProcessing(d_model, p, sequence='n')
        self.postprocess_ffn = PrePostProcessing(d_model,
                                                 p,
                                                 sequence='da',
                                                 static=True)

        if onmt.constants.activation_layer == 'linear_relu_linear':
            ff_p = p
            feedforward = FeedForward(d_model, d_ff, ff_p)
        elif onmt.constants.activation_layer == 'maxout':
            k = int(math.ceil(d_ff / d_model))
            feedforward = MaxOut(d_model, d_model, k)
        else:
            raise NotImplementedError
        self.feedforward = Bottle(feedforward)
Example #23
def postprocessing(rezero,
                   model_size,
                   dropout,
                   variational=False,
                   post_norm=False):
    sequence = 'd'
    if rezero:
        sequence += 'z'
    else:
        sequence += 'a'
    if post_norm:
        sequence += 'n'

    return PrePostProcessing(model_size,
                             dropout,
                             sequence=sequence,
                             variational=variational)
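
Together with the preprocessing helper from Example #13, this helper selects the PrePostProcessing sequence string for the pre-norm, post-norm and ReZero variants. Below is a small sketch of what the common configurations resolve to, assuming both helpers live in the same module; the letter meanings ('n' layer norm, 'd' dropout, 'a' residual add, 'z' ReZero-style add) are inferred from how these modules are used across the examples, and the sizes are illustrative.

# pre-norm (the default across these examples)
pre_in  = preprocessing(rezero=False, model_size=512)                        # sequence 'n'
pre_out = postprocessing(rezero=False, model_size=512, dropout=0.1)          # sequence 'da'

# ReZero: no layer norm, scaled residual add instead
rz_in  = preprocessing(rezero=True, model_size=512)                          # sequence ''
rz_out = postprocessing(rezero=True, model_size=512, dropout=0.1)            # sequence 'dz'

# post-norm: layer norm applied after the residual add
pn_in  = preprocessing(rezero=False, model_size=512, post_norm=True)         # sequence ''
pn_out = postprocessing(rezero=False, model_size=512, dropout=0.1, post_norm=True)  # sequence 'dan'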
Example #24
    def __init__(self, opt, death_rate=0.0):
        super().__init__()
        self.ignore_source = opt.ignore_source
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.fast_self_attention = opt.fast_self_attention
        self.factor_size = opt.layers
        self.adaptive_type = opt.adaptive

        self.preprocess_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da',
                                                  variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
            self.postprocess_src_attn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da',
                                                          variational=self.variational)

            if self.adaptive_type == 'universal':
                self.multihead_src = EncdecMultiheadAttn(opt.n_heads, opt.model_size, opt.attn_dropout)
            else:
                self.multihead_src = AdaptiveEncDecAttn(opt.n_heads, opt.model_size, self.factor_size, opt.attn_dropout)

        self.preprocess_ffn = PrePostProcessing(opt.model_size, opt.dropout, sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size, opt.dropout, sequence='da',
                                                 variational=self.variational)

        if self.adaptive_type == 'universal':
            self.multihead_tgt = RelativeSelfMultiheadAttn(opt.model_size, opt.n_heads, opt.attn_dropout)

            self.feedforward = PositionWiseFeedForward(opt.model_size, opt.inner_size, opt.dropout,
                                                       variational=self.variational)
        else:
            self.multihead_tgt = AdaptiveRelativeAttn(opt.model_size, opt.n_heads, self.factor_size,
                                                      opt.attn_dropout)
            self.feedforward = AdaptiveFeedForward(opt.model_size, opt.inner_size, self.factor_size,
                                                   opt.dropout, variational=self.variational)
Example #25
    def __init__(self, opt, embedding, encoder_type='audio'):
        super(SpeechLSTMEncoder, self).__init__()
        self.opt = opt
        self.model_size = opt.model_size

        if hasattr(opt, 'encoder_layers') and opt.encoder_layers != -1:
            self.layers = opt.encoder_layers
        else:
            self.layers = opt.layers

        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout

        self.input_type = encoder_type
        self.cnn_downsampling = opt.cnn_downsampling

        self.switchout = 0.0  # for speech this has to be 0.0
        self.varitional_dropout = opt.variational_dropout
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type

        self.time = opt.time
        self.lsh_src_attention = opt.lsh_src_attention
        self.reversible = opt.src_reversible
        self.multilingual_factorized_weights = opt.multilingual_factorized_weights
        self.mfw_rank = opt.mfw_rank

        feature_size = opt.input_size
        self.channels = 1

        if opt.upsampling:
            feature_size = feature_size // 4

        if not self.cnn_downsampling:
            self.audio_trans = nn.Linear(feature_size, self.model_size)
            torch.nn.init.xavier_uniform_(self.audio_trans.weight)
        else:
            channels = self.channels
            cnn = [nn.Conv2d(channels, 32, kernel_size=(3, 3), stride=2), nn.ReLU(True), nn.BatchNorm2d(32),
                   nn.Conv2d(32, 32, kernel_size=(3, 3), stride=2), nn.ReLU(True), nn.BatchNorm2d(32)]

            feat_size = (((feature_size // channels) - 3) // 4) * 32
            # cnn.append()
            self.audio_trans = nn.Sequential(*cnn)
            self.linear_trans = nn.Linear(feat_size, self.model_size)

        self.unidirect = False

        self.rnn = nn.LSTM(input_size=self.model_size, hidden_size=self.model_size, num_layers=self.layers,
                           bidirectional=(not self.unidirect), bias=False, dropout=self.dropout, batch_first=True)

        if self.multilingual_factorized_weights:
            from onmt.modules.weight_control_lstm import WeightFactoredLSTM
            self.rnn = WeightFactoredLSTM(self.rnn, dropout=opt.weight_drop, n_languages=opt.n_languages,
                                          rank=self.mfw_rank)

        self.preprocess_layer = PrePostProcessing(self.model_size, self.emb_dropout, sequence='d',
                                                  variational=self.varitional_dropout)
        self.postprocess_layer = PrePostProcessing(self.model_size, 0, sequence='n')
Example #26
    def __init__(self, opt, embedding, encoder_type='audio'):
        super(SpeechLSTMEncoder, self).__init__()
        self.opt = opt
        self.model_size = opt.model_size

        if hasattr(opt, 'encoder_layers') and opt.encoder_layers != -1:
            self.layers = opt.encoder_layers
        else:
            self.layers = opt.layers

        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout

        self.input_type = encoder_type
        self.cnn_downsampling = opt.cnn_downsampling

        self.switchout = opt.switchout
        self.varitional_dropout = opt.variational_dropout
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type

        self.time = opt.time
        self.lsh_src_attention = opt.lsh_src_attention
        self.reversible = opt.src_reversible

        if self.switchout > 0.0:
            self.word_dropout = 0.0

        feature_size = opt.input_size
        self.channels = 1

        if opt.upsampling:
            feature_size = feature_size // 4

        if not self.cnn_downsampling:
            self.audio_trans = nn.Linear(feature_size, self.model_size)
            torch.nn.init.xavier_uniform_(self.audio_trans.weight)
        else:
            channels = self.channels
            cnn = [
                nn.Conv2d(channels, 32, kernel_size=(3, 3), stride=2),
                nn.ReLU(True),
                nn.BatchNorm2d(32),
                nn.Conv2d(32, 32, kernel_size=(3, 3), stride=2),
                nn.ReLU(True),
                nn.BatchNorm2d(32)
            ]

            feat_size = (((feature_size // channels) - 3) // 4) * 32
            # cnn.append()
            self.audio_trans = nn.Sequential(*cnn)
            self.linear_trans = nn.Linear(feat_size, self.model_size)
            # assert self.model_size == feat_size, \
            #     "The model dimension doesn't match with the feature dim, expecting %d " % feat_size

        # if use_cnn:
        #     cnn = [nn.Conv2d(1, 32, kernel_size=(3, freq_kn), stride=(2, freq_std)),
        #            nn.Conv2d(32, 32, kernel_size=(3, freq_kn), stride=(2, freq_std))]
        #     self.cnn = nn.Sequential(*cnn)
        #     input_size = ((((input_size - freq_kn) // freq_std + 1) - freq_kn) // freq_std + 1) * 32
        # else:
        #     self.cnn = None

        self.unidirect = self.opt.unidirectional

        self.rnn = nn.LSTM(input_size=self.model_size,
                           hidden_size=self.model_size,
                           num_layers=self.layers,
                           bidirectional=(not self.unidirect),
                           bias=False,
                           dropout=self.dropout,
                           batch_first=True)

        self.rnn_1 = nn.LSTM(input_size=self.model_size,
                             hidden_size=self.model_size,
                             num_layers=self.layers,
                             bias=False,
                             dropout=self.dropout,
                             batch_first=True)

        self.rnn_2 = nn.LSTM(input_size=self.model_size,
                             hidden_size=self.model_size,
                             num_layers=self.layers,
                             bias=False,
                             dropout=self.dropout,
                             batch_first=True)
        self.preprocess_layer = PrePostProcessing(
            self.model_size,
            self.emb_dropout,
            sequence='d',
            variational=self.varitional_dropout)
        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')
Example #27
    def __init__(self, opt, death_rate=0.0):
        super(RelativeTransformerDecoderLayer, self).__init__()
        self.ignore_source = opt.ignore_source
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.fast_self_attention = opt.fast_self_attention
        # self.lfv_multilingual = opt.lfv_multilingual

        self.preprocess_attn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(opt.model_size,
                                                         opt.dropout,
                                                         sequence='n')
            self.postprocess_src_attn = PrePostProcessing(
                opt.model_size,
                opt.dropout,
                sequence='da',
                variational=self.variational)

            if opt.fast_xattention:
                self.multihead_src = EncdecMultiheadAttn(
                    opt.n_heads, opt.model_size, opt.attn_dropout)
            else:
                self.multihead_src = MultiHeadAttention(
                    opt.n_heads,
                    opt.model_size,
                    attn_p=opt.attn_dropout,
                    share=2)

        self.preprocess_ffn = PrePostProcessing(opt.model_size,
                                                opt.dropout,
                                                sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='da',
                                                 variational=self.variational)

        d_head = opt.model_size // opt.n_heads

        if not self.fast_self_attention:
            self.multihead_tgt = RelPartialLearnableMultiHeadAttn(
                opt.n_heads, opt.model_size, d_head, dropatt=opt.attn_dropout)
        else:
            self.multihead_tgt = RelativeSelfMultiheadAttn(
                opt.model_size, opt.n_heads, opt.attn_dropout)

        if not opt.fast_feed_forward:
            feedforward = FeedForward(opt.model_size,
                                      opt.inner_size,
                                      opt.dropout,
                                      variational=self.variational)
            self.feedforward = Bottle(feedforward)
        else:
            self.feedforward = PositionWiseFeedForward(
                opt.model_size,
                opt.inner_size,
                opt.dropout,
                variational=self.variational)
Example #28
    def __init__(self,
                 opt,
                 embedding,
                 positional_encoder,
                 encoder_type='text',
                 language_embeddings=None):

        super(TransformerEncoder, self).__init__()

        self.opt = opt
        self.model_size = opt.model_size
        self.n_heads = opt.n_heads
        self.inner_size = opt.inner_size
        if hasattr(opt, 'encoder_layers') and opt.encoder_layers != -1:
            self.layers = opt.encoder_layers
        else:
            self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout

        self.input_type = encoder_type
        self.cnn_downsampling = opt.cnn_downsampling
        self.death_rate = opt.death_rate

        self.switchout = opt.switchout
        self.varitional_dropout = opt.variational_dropout
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type

        self.time = opt.time
        self.lsh_src_attention = opt.lsh_src_attention
        self.reversible = opt.src_reversible

        # disable word dropout when switch out is in action
        if self.switchout > 0.0:
            self.word_dropout = 0.0

        feature_size = opt.input_size
        self.channels = 1  # n. audio channels

        if opt.upsampling:
            feature_size = feature_size // 4

        if encoder_type != "text":
            if not self.cnn_downsampling:
                self.audio_trans = nn.Linear(feature_size, self.model_size)
                torch.nn.init.xavier_uniform_(self.audio_trans.weight)
            else:
                channels = self.channels  # should be 1

                if not opt.no_batch_norm:
                    cnn = [
                        nn.Conv2d(channels, 32, kernel_size=(3, 3), stride=2),
                        nn.ReLU(True),
                        nn.BatchNorm2d(32),
                        nn.Conv2d(32, 32, kernel_size=(3, 3), stride=2),
                        nn.ReLU(True),
                        nn.BatchNorm2d(32)
                    ]
                else:
                    cnn = [
                        nn.Conv2d(channels, 32, kernel_size=(3, 3), stride=2),
                        nn.ReLU(True),
                        nn.Conv2d(32, 32, kernel_size=(3, 3), stride=2),
                        nn.ReLU(True)
                    ]

                feat_size = (((feature_size // channels) - 3) // 4) * 32
                self.audio_trans = nn.Sequential(*cnn)
                self.linear_trans = nn.Linear(feat_size, self.model_size)
                # assert self.model_size == feat_size, \
                #     "The model dimension doesn't match with the feature dim, expecting %d " % feat_size
        else:
            self.word_lut = embedding

        self.time_transformer = positional_encoder
        self.language_embedding = language_embeddings

        self.preprocess_layer = PrePostProcessing(
            self.model_size,
            self.emb_dropout,
            sequence='d',
            variational=self.varitional_dropout)

        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.positional_encoder = positional_encoder

        self.layer_modules = nn.ModuleList()
        self.build_modules()
Example #29
    def __init__(self,
                 opt,
                 embedding,
                 positional_encoder,
                 language_embeddings=None,
                 ignore_source=False,
                 allocate_positions=True):
        """
        :param opt:
        :param embedding:
        :param positional_encoder:
        :param attribute_embeddings:
        :param ignore_source:
        """
        super(TransformerDecoder, self).__init__()
        opt.ignore_source = ignore_source
        self.opt = opt

        self.model_size = opt.model_size
        self.layers = opt.layers
        self.dropout = opt.dropout
        self.word_dropout = opt.word_dropout
        self.attn_dropout = opt.attn_dropout
        self.emb_dropout = opt.emb_dropout
        self.encoder_type = opt.encoder_type
        self.ignore_source = ignore_source
        self.encoder_cnn_downsampling = opt.cnn_downsampling
        self.variational_dropout = opt.variational_dropout
        self.switchout = opt.switchout
        self.death_rate = opt.death_rate
        self.time = opt.time
        self.use_language_embedding = opt.use_language_embedding
        self.language_embedding_type = opt.language_embedding_type
        self.reversible = opt.tgt_reversible

        if self.switchout > 0:
            self.word_dropout = 0

        self.time_transformer = positional_encoder

        self.preprocess_layer = PrePostProcessing(
            self.model_size,
            self.emb_dropout,
            sequence='d',
            variational=self.variational_dropout)

        self.postprocess_layer = PrePostProcessing(self.model_size,
                                                   0,
                                                   sequence='n')

        self.word_lut = embedding
        self.multi_embedding = not hasattr(self.word_lut, 'weight')

        # Using feature embeddings in models
        self.language_embeddings = language_embeddings

        if self.language_embedding_type == 'concat':
            self.projector = nn.Linear(opt.model_size * 2, opt.model_size)

        self.positional_encoder = positional_encoder

        if allocate_positions:
            if hasattr(self.positional_encoder, 'len_max'):
                len_max = self.positional_encoder.len_max
                mask = torch.ByteTensor(
                    np.triu(np.ones((len_max, len_max)), k=1).astype('uint8'))
                self.register_buffer('mask', mask)

        self.layer_modules = nn.ModuleList()
        self.build_modules()
Example #30
    def __init__(self, opt, death_rate=0.0):
        super(RelativeTransformerDecoderLayer, self).__init__()
        self.ignore_source = opt.ignore_source
        self.variational = opt.variational_dropout
        self.death_rate = death_rate
        self.batch_ensemble = opt.batch_ensemble
        self.mfw = opt.multilingual_factorized_weights
        self.macaron = opt.macaron
        self.ffn_scale = 0.5 if self.macaron else 1
        self.dropout = opt.dropout

        if self.macaron:
            self.preprocess_mcr_ffn = PrePostProcessing(opt.model_size,
                                                        opt.dropout,
                                                        sequence='n')
            self.postprocess_mcr_ffn = PrePostProcessing(
                opt.model_size,
                opt.dropout,
                sequence='da',
                variational=self.variational)

            if self.mfw:
                self.mcr_feedforward = MFWPositionWiseFeedForward(
                    opt.model_size,
                    opt.inner_size,
                    opt.dropout,
                    variational=self.variational,
                    n_languages=opt.n_languages,
                    rank=opt.mfw_rank,
                    use_multiplicative=opt.mfw_multiplicative)
            else:
                self.mcr_feedforward = PositionWiseFeedForward(
                    opt.model_size,
                    opt.inner_size,
                    opt.dropout,
                    variational=self.variational)

        self.preprocess_attn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='n')
        self.postprocess_attn = PrePostProcessing(opt.model_size,
                                                  opt.dropout,
                                                  sequence='da',
                                                  variational=self.variational)

        if not self.ignore_source:
            self.preprocess_src_attn = PrePostProcessing(opt.model_size,
                                                         opt.dropout,
                                                         sequence='n')
            self.postprocess_src_attn = PrePostProcessing(
                opt.model_size,
                opt.dropout,
                sequence='da',
                variational=self.variational)
            # if self.batch_ensemble > 0:
            #     self.multihead_src = BEEncdecMultiheadAttn(opt.n_heads, opt.model_size, opt.attn_dropout,
            #                                                ensemble=self.batch_ensemble)
            # else:

            if not self.mfw:
                self.multihead_src = EncdecMultiheadAttn(
                    opt.n_heads, opt.model_size, opt.attn_dropout)
            else:
                self.multihead_src = MFWEncdecMultiheadAttn(
                    opt.n_heads,
                    opt.model_size,
                    opt.attn_dropout,
                    n_languages=opt.n_languages,
                    rank=opt.mfw_rank,
                    use_multiplicative=opt.mfw_multiplicative)

        self.preprocess_ffn = PrePostProcessing(opt.model_size,
                                                opt.dropout,
                                                sequence='n')
        self.postprocess_ffn = PrePostProcessing(opt.model_size,
                                                 opt.dropout,
                                                 sequence='da',
                                                 variational=self.variational)

        d_head = opt.model_size // opt.n_heads

        if self.mfw:
            self.feedforward = MFWPositionWiseFeedForward(
                opt.model_size,
                opt.inner_size,
                opt.dropout,
                variational=self.variational,
                n_languages=opt.n_languages,
                rank=opt.mfw_rank,
                use_multiplicative=opt.mfw_multiplicative)

            self.multihead_tgt = MFWRelativeSelfMultiheadAttn(
                opt.model_size,
                opt.n_heads,
                opt.attn_dropout,
                n_languages=opt.n_languages,
                rank=opt.mfw_rank,
                use_multiplicative=opt.mfw_multiplicative)
        else:

            self.feedforward = PositionWiseFeedForward(
                opt.model_size,
                opt.inner_size,
                opt.dropout,
                variational=self.variational)

            self.multihead_tgt = RelativeSelfMultiheadAttn(
                opt.model_size, opt.n_heads, opt.attn_dropout)