Example #1
0
    def __init__(self, special_symbols, enc_n_units, attn_type, n_heads,
                 n_layers, d_model, d_ff, pe_type, layer_norm_eps,
                 ffn_activation, vocab, tie_embedding, dropout, dropout_emb,
                 dropout_att, lsm_prob, ctc_weight, ctc_lsm_prob, ctc_fc_list,
                 backward, global_weight, mtl_per_batch, param_init):
        """Build a Transformer decoder with an optional auxiliary CTC branch.

        The CTC branch is created only when ``ctc_weight > 0``; the attention
        (Transformer) branch is created only when ``ctc_weight < global_weight``.

        Args:
            special_symbols (dict): must provide the keys 'eos', 'unk', 'pad'
                and 'blank'; each value is stored as the attribute of that name.
            enc_n_units (int): encoder output size, forwarded to CTC.
            attn_type, n_heads, d_ff, dropout_att, layer_norm_eps,
                ffn_activation: forwarded to every TransformerDecoderBlock.
            n_layers (int): number of decoder blocks.
            d_model (int): size of the token embedding, of the decoder blocks
                and of the output projection's input.
            pe_type: positional-encoding type passed to PositionalEncoding.
            vocab (int): vocabulary size (embedding rows and output classes).
            tie_embedding (bool): share the output projection weight with the
                token embedding weight.
            dropout (float): forwarded to CTC and to every decoder block.
            dropout_emb (float): dropout rate passed to PositionalEncoding.
            lsm_prob (float): stored as ``self.lsm_prob``; presumably the
                label-smoothing probability used by the loss (unused here).
            ctc_weight (float): weight of the CTC branch; > 0 enables it.
            ctc_lsm_prob, ctc_fc_list: forwarded to CTC.
            backward (bool): stored as ``self.bwd``.
            global_weight (float): presumably the overall loss weight of this
                decoder; here it is only compared with ``ctc_weight``.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            param_init: forwarded to every decoder block;
                ``self.reset_parameters()`` is called only when it equals
                'xavier_uniform'.
        """

        super(TransformerDecoder, self).__init__()

        self.eos = special_symbols['eos']
        self.unk = special_symbols['unk']
        self.pad = special_symbols['pad']
        self.blank = special_symbols['blank']
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # Initial values of cache-like state; presumably updated by decoding
        # methods elsewhere in the class (TODO confirm).
        self.prev_spk = ''
        self.lmstate_final = None

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            # NOTE: CTC always uses param_init=0.1, independent of param_init.
            self.ctc = CTC(eos=self.eos,
                           blank=self.blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            # padding_idx keeps the pad token's embedding row out of gradients.
            self.embed = nn.Embedding(vocab, d_model, padding_idx=self.pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
            # ``repeat`` is a project helper; presumably it returns n_layers
            # independent copies of the block (TODO confirm).
            self.layers = repeat(
                TransformerDecoderBlock(d_model, d_ff, attn_type, n_heads,
                                        dropout, dropout_att, layer_norm_eps,
                                        ffn_activation, param_init), n_layers)
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.output = nn.Linear(d_model, vocab)
            if tie_embedding:
                # Weight tying: output projection and embedding share storage.
                self.output.weight = self.embed.weight

            # Only the xavier_uniform scheme triggers an explicit re-init here.
            if param_init == 'xavier_uniform':
                self.reset_parameters()
Example #2
0
    def __init__(self, special_symbols,
                 enc_n_units, attn_type, n_heads, n_layers,
                 d_model, d_ff, ffn_bottleneck_dim,
                 pe_type, layer_norm_eps, ffn_activation,
                 vocab, tie_embedding,
                 dropout, dropout_emb, dropout_att, dropout_layer, dropout_head,
                 lsm_prob, ctc_weight, ctc_lsm_prob, ctc_fc_list, backward,
                 global_weight, mtl_per_batch, param_init,
                 mma_chunk_size, mma_n_heads_mono, mma_n_heads_chunk,
                 mma_init_r, mma_eps, mma_std,
                 mma_no_denominator, mma_1dconv,
                 mma_quantity_loss_weight, mma_headdiv_loss_weight,
                 latency_metric, latency_loss_weight,
                 mma_first_layer, share_chunkwise_attention,
                 external_lm, lm_fusion):
        """Build a Transformer decoder with optional CTC, MMA and external-LM parts.

        The CTC branch is created when ``ctc_weight > 0``. The attention branch
        (embedding, positional encoding, decoder blocks, output layer) is
        created only when ``att_weight = global_weight - ctc_weight`` is > 0.

        Args:
            special_symbols (dict): must provide 'eos', 'unk', 'pad', 'blank'.
            enc_n_units (int): encoder output size, forwarded to CTC.
            attn_type: stored as ``self.attn_type`` and forwarded to every
                TransformerDecoderBlock.
            n_heads, d_ff, ffn_bottleneck_dim, layer_norm_eps, ffn_activation,
                dropout_att, dropout_layer, dropout_head, lm_fusion,
                share_chunkwise_attention: forwarded to every decoder block.
            n_layers (int): number of decoder blocks.
            d_model (int): embedding / block / output-projection input size.
            pe_type: positional-encoding type for PositionalEncoding.
            vocab (int): vocabulary size.
            tie_embedding (bool): share the output weight with the embedding.
            dropout (float): forwarded to CTC and to every decoder block.
            dropout_emb (float): dropout rate passed to PositionalEncoding.
            lsm_prob (float): stored as ``self.lsm_prob``; unused here.
            ctc_weight (float): CTC weight; > 0 enables the CTC branch.
            ctc_lsm_prob, ctc_fc_list: forwarded to CTC.
            backward (bool): stored as ``self.bwd`` and forwarded to CTC.
            global_weight (float): ``self.att_weight`` is derived from it.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            param_init: forwarded to PositionalEncoding, the decoder blocks
                and ``self.reset_parameters``.
            mma_*: options for MMA (presumably monotonic multihead attention)
                forwarded to the decoder blocks; the quantity/head-divergence
                loss weights and ``mma_first_layer`` are also stored.
            latency_metric (str): 'ctc_sync' sets ``self.ctc_trigger``, which
                requires ``0 < ctc_weight < 1``.
            latency_loss_weight (float): stored as ``self.latency_loss_weight``.
            external_lm: optional LM stored as ``self.lm``; when given, a
                projection from ``external_lm.output_dim`` to ``d_model`` is
                created.

        Raises:
            AssertionError: ``latency_metric == 'ctc_sync'`` with a
                ``ctc_weight`` outside (0, 1).
        """

        super(TransformerDecoder, self).__init__()

        self.eos = special_symbols['eos']
        self.unk = special_symbols['unk']
        self.pad = special_symbols['pad']
        self.blank = special_symbols['blank']
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        # The attention branch exists only if this is > 0 (see below).
        self.att_weight = global_weight - ctc_weight
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.mtl_per_batch = mtl_per_batch

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None
        self.embed_cache = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # for MMA
        self.attn_type = attn_type
        self.quantity_loss_weight = mma_quantity_loss_weight
        self._quantity_loss_weight = mma_quantity_loss_weight  # for curriculum
        # Stored value is clamped to >= 1 (the raw argument is used below).
        self.mma_first_layer = max(1, mma_first_layer)
        self.headdiv_loss_weight = mma_headdiv_loss_weight

        self.latency_metric = latency_metric
        self.latency_loss_weight = latency_loss_weight
        self.ctc_trigger = (self.latency_metric in ['ctc_sync'])
        if self.ctc_trigger:
            # CTC-synchronous latency needs a CTC branch and an attention branch.
            assert 0 < self.ctc_weight < 1

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            # NOTE: CTC always uses param_init=0.1, independent of param_init.
            self.ctc = CTC(eos=self.eos,
                           blank=self.blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1,
                           backward=backward)

        if self.att_weight > 0:
            # token embedding
            self.embed = nn.Embedding(self.vocab, d_model, padding_idx=self.pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type, param_init)
            # decoder
            # Blocks with index lth < mma_first_layer - 1 get no source-target
            # attention. The deepcopy of a freshly built block is redundant
            # (each iteration already constructs a new block) but harmless.
            self.layers = nn.ModuleList([copy.deepcopy(TransformerDecoderBlock(
                d_model, d_ff, attn_type, n_heads, dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                src_tgt_attention=False if lth < mma_first_layer - 1 else True,
                mma_chunk_size=mma_chunk_size,
                mma_n_heads_mono=mma_n_heads_mono,
                mma_n_heads_chunk=mma_n_heads_chunk,
                mma_init_r=mma_init_r,
                mma_eps=mma_eps,
                mma_std=mma_std,
                mma_no_denominator=mma_no_denominator,
                mma_1dconv=mma_1dconv,
                dropout_head=dropout_head,
                lm_fusion=lm_fusion,
                ffn_bottleneck_dim=ffn_bottleneck_dim,
                share_chunkwise_attention=share_chunkwise_attention)) for lth in range(n_layers)])
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.output = nn.Linear(d_model, self.vocab)
            if tie_embedding:
                # Weight tying: output projection and embedding share storage.
                self.output.weight = self.embed.weight

            self.lm = external_lm
            if external_lm is not None:
                # Maps the external LM's output features into d_model space.
                self.lm_output_proj = nn.Linear(external_lm.output_dim, d_model)

            self.reset_parameters(param_init)
Example #3
0
    def __init__(self,
                 special_symbols,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 dropout=0.,
                 dropout_emb=0.,
                 lsm_prob=0.,
                 ctc_weight=0.,
                 ctc_lsm_prob=0.,
                 ctc_fc_list=None,
                 lm_init=None,
                 global_weight=1.,
                 mtl_per_batch=False,
                 param_init=0.1):
        """Build an RNN-Transducer decoder (prediction and joint network).

        An auxiliary CTC branch is created when ``ctc_weight > 0``; the
        transducer layers are created only when ``ctc_weight < global_weight``.

        Args:
            special_symbols (dict): must provide 'eos', 'unk', 'pad', 'blank'.
            enc_n_units (int): encoder output size (CTC input and input of the
                joint network's encoder projection ``w_enc``).
            rnn_type (str): 'lstm_transducer' or 'gru_transducer'.
            n_units (int): hidden size of each prediction-network RNN layer.
            n_projs (int): per-layer projection size; 0 disables projections.
            n_layers (int): number of prediction-network RNN layers.
            bottleneck_dim (int): hidden size of the joint network.
            emb_dim (int): token embedding size (input of the first RNN layer).
            vocab (int): vocabulary size.
            dropout (float): rate of ``self.dropout`` and of CTC dropout.
            dropout_emb (float): rate of ``self.dropout_emb``.
            lsm_prob (float): stored as ``self.lsm_prob``; unused here.
            ctc_weight (float): CTC weight; > 0 enables the CTC branch.
            ctc_lsm_prob (float): forwarded to CTC.
            ctc_fc_list (list or None): forwarded to CTC; ``None`` means an
                empty list (avoids a shared mutable default argument).
            lm_init: optional pre-trained LM; after initialization, every
                parameter with a matching name and size, except those whose
                name contains 'output', is overwritten with the LM's.
            global_weight (float): compared with ``ctc_weight``.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            param_init: passed to ``self.reset_parameters``.

        Raises:
            AssertionError: unknown ``rnn_type``, or ``lm_init`` whose vocab,
                n_units, n_projs or n_layers differ from this decoder's.
        """
        super(RNNTransducer, self).__init__()

        # Fresh list per call instead of a shared mutable default.
        if ctc_fc_list is None:
            ctc_fc_list = []

        self.eos = special_symbols['eos']
        self.unk = special_symbols['unk']
        self.pad = special_symbols['pad']
        self.blank = special_symbols['blank']
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm_transducer', 'gru_transducer']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None
        self.state_cache = OrderedDict()

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            # NOTE: CTC always uses param_init=0.1, independent of param_init.
            self.ctc = CTC(eos=self.eos,
                           blank=self.blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            # import warprnnt_pytorch
            # self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()

            # Prediction network: a stack of single-layer RNNs.
            rnn_l = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU
            self.rnn = nn.ModuleList()
            self.dropout = nn.Dropout(p=dropout)
            if n_projs > 0:
                # Each projection maps an RNN output (n_units) to n_projs.
                self.proj = repeat(nn.Linear(n_units, n_projs), n_layers)
            dec_idim = emb_dim
            for _ in range(n_layers):
                self.rnn += [rnn_l(dec_idim, n_units, 1, batch_first=True)]
                dec_idim = n_projs if n_projs > 0 else n_units

            self.embed = nn.Embedding(vocab, emb_dim, padding_idx=self.pad)
            self.dropout_emb = nn.Dropout(p=dropout_emb)

            # Joint network: encoder and prediction-network features are both
            # projected to bottleneck_dim before the output layer.
            self.w_enc = nn.Linear(enc_n_units, bottleneck_dim)
            self.w_dec = nn.Linear(dec_idim, bottleneck_dim, bias=False)
            self.output = nn.Linear(bottleneck_dim, vocab)

        self.reset_parameters(param_init)

        # prediction network initialization with pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.n_projs == n_projs
            assert lm_init.n_layers == n_layers

            param_dict = dict(lm_init.named_parameters())
            for n, p in self.named_parameters():
                # The output layer is never copied from the LM.
                if 'output' in n:
                    continue
                if n in param_dict and p.size() == param_dict[n].size():
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s', n)
Example #4
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 vocab,
                 tie_embedding=False,
                 pe_type='add',
                 layer_norm_eps=1e-12,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 focal_loss_weight=0.0,
                 focal_loss_gamma=2.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=None,
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):
        """Build a Transformer decoder with optional CTC and adaptive softmax.

        The CTC branch is created when ``ctc_weight > 0``; the attention
        branch is created only when ``ctc_weight < global_weight``.

        Args:
            eos, unk, pad, blank (int): special token indices.
            enc_n_units (int): encoder output size, forwarded to CTC.
            attn_type, attn_n_heads, d_ff, dropout, dropout_att,
                layer_norm_eps: forwarded to every TransformerDecoderBlock.
            n_layers (int): number of decoder blocks.
            d_model (int): embedding / block / output-projection input size.
            vocab (int): vocabulary size; also stored as ``self.vocab``.
            tie_embedding (bool): share the output weight with the embedding.
                Ignored (with a warning) when ``adaptive_softmax`` is set.
            pe_type (str): positional-encoding type for PositionalEncoding.
            dropout_emb (float): dropout rate passed to PositionalEncoding.
            lsm_prob, focal_loss_weight, focal_loss_gamma (float): stored as
                attributes; unused here.
            ctc_weight (float): CTC weight; > 0 enables the CTC branch.
            ctc_lsm_prob (float): forwarded to CTC.
            ctc_fc_list (list or None): forwarded to CTC; ``None`` means an
                empty list (avoids a shared mutable default argument).
            backward (bool): stored as ``self.bwd``.
            global_weight (float): compared with ``ctc_weight``.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            adaptive_softmax (bool): use nn.AdaptiveLogSoftmaxWithLoss
                instead of a plain output projection.
        """
        super(TransformerDecoder, self).__init__()
        logger = logging.getLogger('training')

        # Fresh list per call instead of a shared mutable default.
        if ctc_fc_list is None:
            ctc_fc_list = []

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        # Needed by the adaptive-softmax cutoffs below.
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.attn_n_heads = attn_n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.focal_loss_weight = focal_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            # NOTE: CTC always uses param_init=0.1.
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            self.embed = Embedding(
                vocab,
                d_model,
                dropout=0,  # NOTE: do not apply dropout here
                ignore_index=pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
            self.layers = nn.ModuleList([
                TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                        dropout, dropout_att, layer_norm_eps)
                for _ in range(n_layers)
            ])
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

            if adaptive_softmax:
                if tie_embedding:
                    logger.warning(
                        'tie_embedding is ignored when adaptive_softmax is used.')
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model,
                    vocab,
                    # NOTE: cutoffs must be increasing and lie in
                    # [1, vocab - 1], so very small vocabularies are rejected
                    # by PyTorch.
                    cutoffs=[
                        round(self.vocab / 15), 3 * round(self.vocab / 15)
                    ],
                    # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = Linear(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters()
Example #5
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 tie_embedding=False,
                 attn_conv_kernel_size=0,
                 dropout=0.0,
                 dropout_emb=0.0,
                 lsm_prob=0.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=None,
                 backward=False,
                 lm_fusion=None,
                 lm_fusion_type='cold',
                 discourse_aware='',
                 lm_init=None,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 param_init=0.1,
                 replace_sos=False,
                 soft_label_weight=0.0):
        """Build an RNN decoder on top of a CIF module, with optional CTC and LM fusion.

        The CTC branch is created when ``ctc_weight > 0``. The CIF module, the
        RNN layers, the optional LM-fusion layers, the embedding and the output
        layer are created only when ``ctc_weight < global_weight``.

        Args:
            eos, unk, pad, blank (int): special token indices.
            enc_n_units (int): encoder output size.
            rnn_type (str): 'lstm' or 'gru'.
            n_units (int): hidden size of each RNN layer.
            n_projs (int): per-layer projection size; 0 disables projections.
            n_layers (int): number of RNN layers.
            bottleneck_dim (int): size of the layer feeding the output layer.
            emb_dim (int): token embedding size.
            vocab (int): vocabulary size.
            tie_embedding (bool): share the output weight with the embedding;
                requires ``emb_dim == bottleneck_dim``.
            attn_conv_kernel_size (int): convolution kernel size for CIF.
            dropout (float): dropout rate for the RNN layers and CTC.
            dropout_emb (float): dropout rate of the embedding.
            lsm_prob (float): stored as ``self.lsm_prob``; unused here.
            ctc_weight (float): CTC weight; > 0 enables the CTC branch.
            ctc_lsm_prob (float): forwarded to CTC.
            ctc_fc_list (list or None): forwarded to CTC; ``None`` means an
                empty list (avoids a shared mutable default argument).
            backward (bool): stored as ``self.bwd``.
            lm_fusion: optional external LM; its parameters are frozen and it
                is stored as ``self.lm``.
            lm_fusion_type (str): 'cold', 'deep' or 'cold_prob'.
            discourse_aware (str): 'hierarchical' is not implemented.
            lm_init: optional pre-trained RNNLM whose RNN and embedding weights
                overwrite the decoder's after initialization.
            global_weight (float): compared with ``ctc_weight``.
            mtl_per_batch, replace_sos, soft_label_weight: stored as attributes.
            param_init: passed to CTC and to ``self.reset_parameters``.

        Raises:
            ValueError: unknown ``lm_fusion_type`` when ``lm_fusion`` is given,
                or ``tie_embedding`` with ``emb_dim != bottleneck_dim``.
            NotImplementedError: ``discourse_aware == 'hierarchical'`` without
                ``lm_fusion``.
            AssertionError: unknown ``rnn_type`` or an incompatible ``lm_init``.
        """
        super(CIFRNNDecoder, self).__init__()
        logger = logging.getLogger('training')

        # Fresh list per call instead of a shared mutable default.
        if ctc_fc_list is None:
            ctc_fc_list = []

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm', 'gru']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.lm_fusion_type = lm_fusion_type
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch
        self.replace_sos = replace_sos
        self.soft_label_weight = soft_label_weight

        self.quantity_loss_weight = 1.0

        # for contextualization
        self.discourse_aware = discourse_aware
        self.dstate_prev = None

        # for cache
        self.prev_spk = ''
        self.total_step = 0
        self.dstates_final = None
        self.lmstate_final = None

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=param_init)

        if ctc_weight < global_weight:
            # Attention layer
            self.score = CIF(enc_dim=self.enc_n_units,
                             conv_kernel_size=attn_conv_kernel_size,
                             conv_out_channels=self.enc_n_units)

            # Decoder
            self.rnn = nn.ModuleList()
            if self.n_projs > 0:
                # Each projection maps an RNN output (n_units) to n_projs.
                self.proj = nn.ModuleList(
                    [Linear(n_units, n_projs) for _ in range(n_layers)])
            self.dropout = nn.ModuleList(
                [nn.Dropout(p=dropout) for _ in range(n_layers)])
            rnn = nn.LSTM if rnn_type == 'lstm' else nn.GRU
            # The first layer takes enc_n_units + emb_dim features (presumably
            # the CIF output concatenated with the token embedding).
            dec_odim = enc_n_units + emb_dim
            for _ in range(n_layers):
                self.rnn += [rnn(dec_odim, n_units, 1)]
                dec_odim = n_projs if self.n_projs > 0 else n_units

            # LM fusion
            if lm_fusion is not None:
                self.linear_dec_feat = Linear(dec_odim + enc_n_units, n_units)
                if lm_fusion_type in ['cold', 'deep']:
                    self.linear_lm_feat = Linear(lm_fusion.n_units, n_units)
                    self.linear_lm_gate = Linear(n_units * 2, n_units)
                elif lm_fusion_type == 'cold_prob':
                    self.linear_lm_feat = Linear(lm_fusion.vocab, n_units)
                    self.linear_lm_gate = Linear(n_units * 2, n_units)
                else:
                    raise ValueError(lm_fusion_type)
                self.output_bn = Linear(n_units * 2, bottleneck_dim)

                # fix LM parameters
                for p in lm_fusion.parameters():
                    p.requires_grad = False
            elif discourse_aware == 'hierarchical':
                raise NotImplementedError
            else:
                self.output_bn = Linear(dec_odim + enc_n_units, bottleneck_dim)

            self.embed = Embedding(vocab,
                                   emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            self.output = Linear(bottleneck_dim, vocab)
            # NOTE: include bias even when tying weights

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if tie_embedding:
                if emb_dim != bottleneck_dim:
                    raise ValueError(
                        'When using the tied flag, bottleneck_dim must be equal to emb_dim.'
                    )
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters(param_init)

        # resister the external LM
        self.lm = lm_fusion

        # decoder initialization with pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.emb_dim == emb_dim
            logger.info('===== Initialize the decoder with pre-trained RNNLM')
            assert lm_init.n_projs == 0  # TODO(hirofumi): fix later
            assert lm_init.n_units_null_context == enc_n_units
            # Every LM layer is copied into the layer with the same index.
            assert lm_init.n_layers <= n_layers

            # RNN
            for lth in range(lm_init.n_layers):
                for n, p in lm_init.rnn[lth].named_parameters():
                    assert getattr(self.rnn[lth], n).size() == p.size()
                    getattr(self.rnn[lth], n).data = p.data
                    logger.info('Overwrite %s', n)

            # embedding
            assert self.embed.embed.weight.size(
            ) == lm_init.embed.embed.weight.size()
            self.embed.embed.weight.data = lm_init.embed.embed.weight.data
            logger.info('Overwrite %s', 'embed.embed.weight')
Example #6
0
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 rnn_type,
                 n_units,
                 n_projs,
                 n_layers,
                 residual,
                 bottleneck_dim,
                 emb_dim,
                 vocab,
                 tie_embedding=False,
                 dropout=0.0,
                 dropout_emb=0.0,
                 lsm_prob=0.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=None,
                 lm_init=None,
                 lmobj_weight=0.0,
                 share_lm_softmax=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 param_init=0.1,
                 start_pointing=False,
                 end_pointing=True):
        """Build an RNN-Transducer decoder with optional CTC and LM-objective heads.

        The CTC branch is created when ``ctc_weight > 0``; the transducer
        (warp-rnnt loss, prediction network, joint network) is created only
        when ``ctc_weight < global_weight``.

        Args:
            eos, unk, pad, blank (int): special token indices.
            enc_n_units (int): encoder output size.
            rnn_type (str): 'lstm_transducer' or 'gru_transducer'.
            n_units (int): hidden size of each prediction-network RNN layer.
            n_projs (int): per-layer projection size; 0 disables projections.
            n_layers (int): number of prediction-network RNN layers.
            residual (bool): stored as ``self.residual``; together with
                ``n_projs > 0`` it selects the layer-by-layer implementation.
            bottleneck_dim (int): hidden size of the joint network.
            emb_dim (int): token embedding size.
            vocab (int): vocabulary size.
            tie_embedding (bool): accepted for interface compatibility; not
                used in this constructor.
            dropout (float): dropout rate for the prediction network and CTC.
            dropout_emb (float): dropout rate of the embedding.
            lsm_prob (float): stored as ``self.lsm_prob``; unused here.
            ctc_weight (float): CTC weight; > 0 enables the CTC branch.
            ctc_lsm_prob (float): forwarded to CTC.
            ctc_fc_list (list or None): forwarded to CTC; ``None`` means an
                empty list (avoids a shared mutable default argument).
            lm_init: optional pre-trained LM; after initialization, every
                parameter with a matching name and size, except those whose
                name contains 'output', is overwritten with the LM's.
            lmobj_weight (float): > 0 adds ``self.output_lmobj``.
            share_lm_softmax (bool): make ``self.output_lmobj`` the same
                module as ``self.output`` instead of a separate layer.
            global_weight (float): compared with ``ctc_weight``.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            param_init: passed to CTC and to ``self.reset_parameters``.
            start_pointing, end_pointing (bool): stored VAD flags.

        Raises:
            AssertionError: unknown ``rnn_type`` or an incompatible ``lm_init``.
            ImportError: ``warprnnt_pytorch`` missing when the transducer is built.
        """
        super(RNNTransducer, self).__init__()
        logger = logging.getLogger('training')

        # Fresh list per call instead of a shared mutable default.
        if ctc_fc_list is None:
            ctc_fc_list = []

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.vocab = vocab
        self.rnn_type = rnn_type
        assert rnn_type in ['lstm_transducer', 'gru_transducer']
        self.enc_n_units = enc_n_units
        self.dec_n_units = n_units
        self.n_projs = n_projs
        self.n_layers = n_layers
        self.residual = residual
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.lmobj_weight = lmobj_weight
        self.share_lm_softmax = share_lm_softmax
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # VAD
        self.start_pointing = start_pointing
        self.end_pointing = end_pointing

        # for cache
        self.prev_spk = ''
        self.lmstate_final = None
        self.state_cache = OrderedDict()

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=param_init)

        if ctc_weight < global_weight:
            import warprnnt_pytorch
            self.warprnnt_loss = warprnnt_pytorch.RNNTLoss()

            # for MTL with LM objective: a dedicated softmax is registered
            # here; the shared variant can only be attached once
            # ``self.output`` exists (end of this branch).
            if lmobj_weight > 0 and not share_lm_softmax:
                self.output_lmobj = Linear(n_units, vocab)

            # Prediction network
            self.fast_impl = False
            rnn = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU
            if n_projs == 0 and not residual:
                # A single multi-layer cuDNN-friendly RNN suffices.
                self.fast_impl = True
                self.rnn = rnn(emb_dim, n_units, n_layers,
                               bias=True,
                               batch_first=True,
                               dropout=dropout,
                               bidirectional=False)
                # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer
                dec_idim = n_units
                self.dropout_top = nn.Dropout(p=dropout)
            else:
                self.rnn = nn.ModuleList()
                self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(n_layers)])
                if n_projs > 0:
                    # Each projection maps an RNN output (n_units) to n_projs.
                    self.proj = nn.ModuleList([Linear(n_units, n_projs) for _ in range(n_layers)])
                dec_idim = emb_dim
                for _ in range(n_layers):
                    self.rnn += [rnn(dec_idim, n_units, 1,
                                     bias=True,
                                     batch_first=True,
                                     dropout=0,
                                     bidirectional=False)]
                    dec_idim = n_projs if n_projs > 0 else n_units

            self.embed = Embedding(vocab, emb_dim,
                                   dropout=dropout_emb,
                                   ignore_index=pad)

            # Joint network
            self.w_enc = Linear(enc_n_units, bottleneck_dim, bias=True)
            self.w_dec = Linear(dec_idim, bottleneck_dim, bias=False)
            self.output = Linear(bottleneck_dim, vocab)

            if lmobj_weight > 0 and share_lm_softmax:
                # NOTE(review): the shared layer expects bottleneck_dim inputs,
                # whereas the dedicated one expects n_units; confirm callers
                # feed features of the matching size.
                self.output_lmobj = self.output  # share parameters

        # Initialize parameters
        self.reset_parameters(param_init)

        # prediction network initialization with pre-trained LM
        if lm_init is not None:
            assert lm_init.vocab == vocab
            assert lm_init.n_units == n_units
            assert lm_init.n_projs == n_projs
            assert lm_init.n_layers == n_layers
            assert lm_init.residual == residual

            param_dict = dict(lm_init.named_parameters())
            for n, p in self.named_parameters():
                # The output layer is never copied from the LM.
                if 'output' in n:
                    continue
                if n in param_dict and p.size() == param_dict[n].size():
                    p.data = param_dict[n].data
                    logger.info('Overwrite %s', n)
Example #7
0
    def __init__(
            self, special_symbols, enc_n_units, attn_type, n_heads, n_layers,
            d_model, d_ff, pe_type, layer_norm_eps, ffn_activation, vocab,
            tie_embedding, dropout, dropout_emb, dropout_att, dropout_residual,
            lsm_prob, ctc_weight, ctc_lsm_prob, ctc_fc_list, backward,
            global_weight, mtl_per_batch, param_init, memory_transformer,
            mocha_chunk_size, mocha_n_heads_mono, mocha_n_heads_chunk,
            mocha_init_r, mocha_eps, mocha_std, mocha_quantity_loss_weight,
            mocha_head_divergence_loss_weight, latency_metric,
            latency_loss_weight, mocha_dropout_head, mocha_dropout_hard,
            mocha_first_layer, external_lm, lm_fusion, mem_len):
        """Build a Transformer(-XL) decoder with optional CTC, MoChA and external LM.

        The CTC branch is created when ``ctc_weight > 0``; the attention
        branch is created only when ``ctc_weight < global_weight``.

        Args:
            special_symbols (dict): must provide 'eos', 'unk', 'pad', 'blank'.
            enc_n_units (int): encoder output size, forwarded to CTC.
            attn_type (str): stored and forwarded to every decoder block; when
                it contains 'mocha', blocks with index
                ``l < mocha_first_layer - 1`` get no source-target attention.
            n_heads (int): number of heads; also sizes ``self.u``/``self.v``.
            n_layers (int): number of decoder blocks.
            d_model, d_ff, layer_norm_eps, ffn_activation, dropout,
                dropout_att, lm_fusion: forwarded to every decoder block.
            pe_type (str): positional-encoding type; must be 'none' when
                ``memory_transformer`` is set.
            vocab (int): vocabulary size.
            tie_embedding (bool): share the output weight with the embedding.
            dropout_emb (float): embedding dropout rate.
            dropout_residual (float): scaled per block by ``(l + 1) / n_layers``.
            lsm_prob (float): stored as ``self.lsm_prob``; unused here.
            ctc_weight (float): CTC weight; > 0 enables the CTC branch.
            ctc_lsm_prob, ctc_fc_list: forwarded to CTC.
            backward (bool): stored as ``self.bwd`` and forwarded to CTC.
            global_weight (float): compared with ``ctc_weight``.
            mtl_per_batch (bool): stored as ``self.mtl_per_batch``.
            param_init: forwarded to PositionalEncoding, the decoder blocks
                and ``self.reset_parameters``.
            memory_transformer (bool): use an XL-style relative positional
                embedding plus global ``u``/``v`` parameters instead of
                PositionalEncoding.
            mocha_*: MoChA options forwarded to the decoder blocks;
                ``mocha_dropout_hard`` is scaled per block by
                ``(n_layers - l) / n_layers``.
            mocha_quantity_loss_weight, mocha_head_divergence_loss_weight,
                latency_metric, latency_loss_weight, mocha_first_layer:
                stored as attributes.
            external_lm: optional LM stored as ``self.lm``; when given, a
                projection from ``external_lm.output_dim`` to ``d_model`` is
                created.
            mem_len (int): stored as ``self.mem_len`` only when
                ``memory_transformer`` is set.

        Raises:
            AssertionError: ``memory_transformer`` with ``pe_type != 'none'``.
        """

        super(TransformerDecoder, self).__init__()

        self.eos = special_symbols['eos']
        self.unk = special_symbols['unk']
        self.pad = special_symbols['pad']
        self.blank = special_symbols['blank']
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        # Initial values of cache-like state; presumably updated by decoding
        # methods elsewhere in the class (TODO confirm).
        self.prev_spk = ''
        self.lmstate_final = None

        # for TransformerXL decoder
        self.memory_transformer = memory_transformer
        if memory_transformer:
            # Absolute positional encoding is replaced by XLPositionalEmbedding.
            assert pe_type == 'none'
            self.mem_len = mem_len

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # for mocha
        self.attn_type = attn_type
        self.quantity_loss_weight = mocha_quantity_loss_weight
        self.headdiv_loss_weight = mocha_head_divergence_loss_weight
        self.latency_metric = latency_metric
        self.latency_loss_weight = latency_loss_weight
        self.mocha_first_layer = mocha_first_layer

        # NOTE: submodules are registered in the order below; keep it stable
        # (state_dict key order and initialization order depend on it).
        if ctc_weight > 0:
            # NOTE: CTC always uses param_init=0.1, independent of param_init.
            self.ctc = CTC(eos=self.eos,
                           blank=self.blank,
                           enc_n_units=enc_n_units,
                           vocab=self.vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1,
                           backward=backward)

        if ctc_weight < global_weight:
            # token embedding
            self.embed = nn.Embedding(self.vocab,
                                      d_model,
                                      padding_idx=self.pad)
            # positional embedding
            if memory_transformer:
                self.dropout_emb = nn.Dropout(p=dropout_emb)
                self.pos_emb = XLPositionalEmbedding(d_model)
                # torch.Tensor(...) allocates uninitialized memory; presumably
                # reset_parameters initializes u and v (TODO confirm).
                self.u = nn.Parameter(
                    torch.Tensor(self.n_heads, self.d_model // self.n_heads))
                self.v = nn.Parameter(
                    torch.Tensor(self.n_heads, self.d_model // self.n_heads))
                # NOTE: u and v are global parameters
            else:
                self.pos_enc = PositionalEncoding(d_model, dropout_emb,
                                                  pe_type, param_init)
            # self-attention
            # The deepcopy of a freshly built block is redundant (each
            # iteration already constructs a new block) but harmless.
            self.layers = nn.ModuleList([
                copy.deepcopy(
                    TransformerDecoderBlock(
                        d_model,
                        d_ff,
                        attn_type,
                        n_heads,
                        dropout,
                        dropout_att,
                        # residual dropout grows linearly with depth
                        dropout_residual * (l + 1) / n_layers,
                        layer_norm_eps,
                        ffn_activation,
                        param_init,
                        src_tgt_attention=False if 'mocha' in attn_type
                        and l < mocha_first_layer - 1 else True,
                        memory_transformer=memory_transformer,
                        mocha_chunk_size=mocha_chunk_size,
                        mocha_n_heads_mono=mocha_n_heads_mono,
                        mocha_n_heads_chunk=mocha_n_heads_chunk,
                        mocha_init_r=mocha_init_r,
                        mocha_eps=mocha_eps,
                        mocha_std=mocha_std,
                        mocha_dropout_head=mocha_dropout_head,
                        mocha_dropout_hard=mocha_dropout_hard *
                        (n_layers - l) / n_layers,  # the lower the stronger
                        lm_fusion=lm_fusion)) for l in range(n_layers)
            ])
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.output = nn.Linear(d_model, self.vocab)
            if tie_embedding:
                # Weight tying: output projection and embedding share storage.
                self.output.weight = self.embed.weight

            self.lm = external_lm
            if external_lm is not None:
                # Maps the external LM's output features into d_model space.
                self.lm_output_proj = nn.Linear(external_lm.output_dim,
                                                d_model)

            self.reset_parameters(param_init)