Example #1
    def __init__(self, args, save_path=None):

        super(LMBase, self).__init__()
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.d_model = args.transformer_d_model
        self.n_layers = args.n_layers
        self.n_heads = args.transformer_n_heads
        self.lsm_prob = args.lsm_prob

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = nn.Embedding(self.vocab,
                                  self.d_model,
                                  padding_idx=self.pad)
        self.pos_enc = PositionalEncoding(self.d_model, args.dropout_in,
                                          args.transformer_pe_type)
        self.layers = repeat(
            TransformerDecoderBlock(self.d_model,
                                    args.transformer_d_ff,
                                    args.transformer_attn_type,
                                    self.n_heads,
                                    args.dropout_hidden,
                                    args.dropout_att,
                                    args.transformer_layer_norm_eps,
                                    args.transformer_ffn_activation,
                                    src_tgt_attention=False), self.n_layers)
        self.norm_out = nn.LayerNorm(self.d_model,
                                     eps=args.transformer_layer_norm_eps)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                self.d_model,
                self.vocab,
                cutoffs=[round(self.vocab / 15), 3 * round(self.vocab / 15)],
                # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = nn.Linear(self.d_model, self.vocab)
            if args.tie_embedding:
                self.output.weight = self.embed.weight

        self.reset_parameters()
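Several of the examples on this page tie the output projection to the input embedding (self.output.weight = self.embed.weight). A minimal stand-alone sketch of that weight-tying pattern, using plain torch.nn modules and hypothetical sizes (this is not the repository's LMBase class):

    import torch
    import torch.nn as nn

    vocab, d_model, pad = 1000, 512, 3   # hypothetical sizes and padding index

    embed = nn.Embedding(vocab, d_model, padding_idx=pad)
    output = nn.Linear(d_model, vocab)

    # Share a single parameter tensor: gradients from the embedding lookup and
    # from the output projection accumulate in the same weight matrix.
    output.weight = embed.weight

    tokens = torch.randint(0, vocab, (2, 7))   # (batch, time)
    logits = output(embed(tokens))             # (batch, time, vocab)
    assert logits.shape == (2, 7, vocab)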
Example #2
    def __init__(self, special_symbols, enc_n_units, attn_type, n_heads,
                 n_layers, d_model, d_ff, pe_type, layer_norm_eps,
                 ffn_activation, vocab, tie_embedding, dropout, dropout_emb,
                 dropout_att, lsm_prob, ctc_weight, ctc_lsm_prob, ctc_fc_list,
                 backward, global_weight, mtl_per_batch, param_init):

        super(TransformerDecoder, self).__init__()

        self.eos = special_symbols['eos']
        self.unk = special_symbols['unk']
        self.pad = special_symbols['pad']
        self.blank = special_symbols['blank']
        self.vocab = vocab
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        self.prev_spk = ''
        self.lmstate_final = None

        if ctc_weight > 0:
            self.ctc = CTC(eos=self.eos,
                           blank=self.blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            self.embed = nn.Embedding(vocab, d_model, padding_idx=self.pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
            self.layers = repeat(
                TransformerDecoderBlock(d_model, d_ff, attn_type, n_heads,
                                        dropout, dropout_att, layer_norm_eps,
                                        ffn_activation, param_init), n_layers)
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
            self.output = nn.Linear(d_model, vocab)
            if tie_embedding:
                self.output.weight = self.embed.weight

            if param_init == 'xavier_uniform':
                self.reset_parameters()
Example #3
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 vocab,
                 tie_embedding=False,
                 pe_type='add',
                 layer_norm_eps=1e-12,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 focal_loss_weight=0.0,
                 focal_loss_gamma=2.0,
                 ctc_weight=0.0,
                 ctc_lsm_prob=0.0,
                 ctc_fc_list=[],
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):

        super(TransformerDecoder, self).__init__()
        logger = logging.getLogger('training')

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.attn_n_heads = attn_n_heads
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.focal_loss_weight = focal_loss_weight
        self.focal_loss_gamma = focal_loss_gamma
        self.ctc_weight = ctc_weight
        self.bwd = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            self.ctc = CTC(eos=eos,
                           blank=blank,
                           enc_n_units=enc_n_units,
                           vocab=vocab,
                           dropout=dropout,
                           lsm_prob=ctc_lsm_prob,
                           fc_list=ctc_fc_list,
                           param_init=0.1)

        if ctc_weight < global_weight:
            self.embed = Embedding(
                vocab,
                d_model,
                dropout=0,  # NOTE: do not apply dropout here
                ignore_index=pad)
            self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
            self.layers = nn.ModuleList([
                TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                        dropout, dropout_att, layer_norm_eps)
                for _ in range(n_layers)
            ])
            self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

            if adaptive_softmax:
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model,
                    vocab,
                    cutoffs=[
                        round(vocab / 15), 3 * round(vocab / 15)
                    ],
                    # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = Linear(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters()
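When adaptive_softmax is enabled, these decoders replace the dense output layer with nn.AdaptiveLogSoftmaxWithLoss, which clusters the vocabulary by frequency at the given cutoffs. A minimal sketch of how that module is driven, with hypothetical dimensions and the same cutoff formula as in the examples:

    import torch
    import torch.nn as nn

    d_model, vocab = 256, 3000   # hypothetical sizes

    adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
        d_model, vocab,
        cutoffs=[round(vocab / 15), 3 * round(vocab / 15)],   # here: [200, 600]
        div_value=4.0)

    hidden = torch.randn(32, d_model)            # flattened decoder outputs (N, d_model)
    targets = torch.randint(0, vocab, (32,))     # gold token ids (N,)

    out, loss = adaptive_softmax(hidden, targets)   # out: log-prob of each target, loss: scalar
    log_probs = adaptive_softmax.log_prob(hidden)   # (N, vocab), useful for scoring/decoding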
Example #4
    def __init__(self, args, save_path=None):

        super(LMBase, self).__init__()
        logger = logging.getLogger('training')
        logger.info(self.__class__.__name__)

        self.save_path = save_path

        self.d_model = args.d_model
        self.d_ff = args.d_ff
        self.pe_type = args.pe_type
        self.n_layers = args.n_layers
        self.n_heads = args.attn_n_heads
        self.tie_embedding = args.tie_embedding

        self.vocab = args.vocab
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # self.lsm_prob = lsm_prob

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []

        self.embed = Embedding(
            vocab=self.vocab,
            emb_dim=self.d_model,
            dropout=0,  # NOTE: do not apply dropout here
            ignore_index=self.pad)
        self.pos_enc = PositionalEncoding(args.d_model, args.dropout_emb,
                                          args.pe_type)

        self.layers = nn.ModuleList([
            TransformerDecoderBlock(args.d_model,
                                    args.d_ff,
                                    args.attn_type,
                                    args.attn_n_heads,
                                    args.dropout_hidden,
                                    args.dropout_att,
                                    args.layer_norm_eps,
                                    src_attention=False)
            for _ in range(self.n_layers)
        ])
        self.norm_out = nn.LayerNorm(args.d_model, eps=args.layer_norm_eps)

        if args.adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                args.d_model,
                self.vocab,
                cutoffs=[round(self.vocab / 15), 3 * round(self.vocab / 15)],
                # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = LinearND(self.d_model, self.vocab)

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if args.tie_embedding:
                self.output.fc.weight = self.embed.embed.weight

        # Initialize parameters
        self.reset_parameters()
Example #5
    def __init__(self,
                 eos,
                 unk,
                 pad,
                 blank,
                 enc_n_units,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type,
                 tie_embedding,
                 vocab,
                 dropout=0.0,
                 dropout_emb=0.0,
                 dropout_att=0.0,
                 lsm_prob=0.0,
                 layer_norm_eps=1e-6,
                 ctc_weight=0.0,
                 ctc_fc_list=[],
                 backward=False,
                 global_weight=1.0,
                 mtl_per_batch=False,
                 adaptive_softmax=False):

        super(TransformerDecoder, self).__init__()

        self.eos = eos
        self.unk = unk
        self.pad = pad
        self.blank = blank
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.pe_type = pe_type
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.ctc_fc_list = ctc_fc_list
        self.backward = backward
        self.global_weight = global_weight
        self.mtl_per_batch = mtl_per_batch

        if ctc_weight > 0:
            # Fully-connected layers for CTC
            if len(ctc_fc_list) > 0:
                fc_layers = OrderedDict()
                for i in range(len(ctc_fc_list)):
                    input_dim = d_model if i == 0 else ctc_fc_list[i - 1]
                    fc_layers['fc' + str(i)] = LinearND(input_dim, ctc_fc_list[i], dropout=dropout)
                fc_layers['fc' + str(len(ctc_fc_list))] = LinearND(ctc_fc_list[-1], vocab, dropout=0)
                self.output_ctc = nn.Sequential(fc_layers)
            else:
                self.output_ctc = LinearND(d_model, vocab)
            self.decode_ctc_greedy = GreedyDecoder(blank=blank)
            self.decode_ctc_beam = BeamSearchDecoder(blank=blank)
            import warpctc_pytorch
            self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True)

        if ctc_weight < global_weight:
            self.layers = nn.ModuleList(
                [TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                         dropout, dropout_att, layer_norm_eps)
                 for _ in range(n_layers)])

            self.embed = Embedding(vocab, d_model,
                                   dropout=0,  # NOTE: do not apply dropout here
                                   ignore_index=pad)
            if pe_type:
                self.pos_emb_out = PositionalEncoding(d_model, dropout_emb, pe_type)

            if adaptive_softmax:
                self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                    d_model, vocab,
                    cutoffs=[round(vocab / 15), 3 * round(vocab / 15)],
                    # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                    div_value=4.0)
                self.output = None
            else:
                self.adaptive_softmax = None
                self.output = LinearND(d_model, vocab)

                # Optionally tie weights as in:
                # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
                # https://arxiv.org/abs/1608.05859
                # and
                # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
                # https://arxiv.org/abs/1611.01462
                if tie_embedding:
                    self.output.fc.weight = self.embed.embed.weight

            self.norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Initialize parameters
        self.reset_parameters()
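The CTC branch in Example #5 stacks optional fully connected layers by filling an OrderedDict and wrapping it in nn.Sequential. A minimal sketch of the same pattern with plain nn.Linear plus an explicit nn.Dropout standing in for the repository's LinearND helper (sizes are hypothetical, and the dropout placement is an assumption):

    from collections import OrderedDict

    import torch.nn as nn

    d_model, vocab = 256, 500     # hypothetical sizes
    ctc_fc_list = [300, 300]      # hidden sizes of the intermediate FC layers
    dropout = 0.1

    fc_layers = OrderedDict()
    for i, fc_dim in enumerate(ctc_fc_list):
        input_dim = d_model if i == 0 else ctc_fc_list[i - 1]
        fc_layers['fc' + str(i)] = nn.Linear(input_dim, fc_dim)
        fc_layers['dropout' + str(i)] = nn.Dropout(p=dropout)   # assumed equivalent of LinearND's dropout
    # final projection onto the vocabulary (which includes the blank symbol)
    fc_layers['fc' + str(len(ctc_fc_list))] = nn.Linear(ctc_fc_list[-1], vocab)

    output_ctc = nn.Sequential(fc_layers)   # named sub-modules: fc0, dropout0, fc1, dropout1, fc2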
Example #6
    def __init__(self,
                 input_dim,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type='add',
                 layer_norm_eps=1e-6,
                 dropout_in=0,
                 dropout=0,
                 dropout_att=0,
                 last_proj_dim=0,
                 n_stacks=1,
                 n_splices=1,
                 conv_in_channel=1,
                 conv_channels=0,
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 conv_poolings=[],
                 conv_batch_norm=False,
                 conv_residual=False,
                 conv_bottleneck_dim=0,
                 param_init=0.1):

        super(TransformerEncoder, self).__init__()
        logger = logging.getLogger("training")

        self.d_model = d_model
        self.n_layers = n_layers
        self.attn_n_heads = attn_n_heads
        self.pe_type = pe_type

        # Setting for CNNs before the self-attention layers
        if conv_channels:
            channels = [int(c) for c in conv_channels.split('_')
                        ] if len(conv_channels) > 0 else []
            kernel_sizes = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_kernel_sizes.split('_')
                            ] if len(conv_kernel_sizes) > 0 else []
            strides = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_strides.split('_')
                       ] if len(conv_strides) > 0 else []
            poolings = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_poolings.split('_')
                        ] if len(conv_poolings) > 0 else []
        else:
            channels = []
            kernel_sizes = []
            strides = []
            poolings = []
            logger.warning(
                'Subsampling is automatically ignored because CNN layers are used before RNN layers.'
            )

        if len(channels) > 0:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=channels,
                                    kernel_sizes=kernel_sizes,
                                    strides=strides,
                                    poolings=poolings,
                                    dropout=0,
                                    batch_norm=conv_batch_norm,
                                    residual=conv_residual,
                                    bottleneck_dim=d_model,
                                    param_init=param_init)
            self._output_dim = self.conv.output_dim
        else:
            self._output_dim = input_dim * n_splices * n_stacks
            self.conv = None

            self.embed = LinearND(self._output_dim, d_model,
                                  dropout=0)  # NOTE: do not apply dropout here

        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                    dropout, dropout_att, layer_norm_eps)
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._output_dim = d_model

        if last_proj_dim != self.output_dim:
            self.bridge = LinearND(self._output_dim,
                                   last_proj_dim,
                                   dropout=dropout)
            self._output_dim = last_proj_dim
        else:
            self.bridge = None

        # Initialize parameters
        self.reset_parameters()
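Examples #6 and #8 receive the convolution front-end configuration as underscore-separated strings (channels such as '32_32', kernel sizes such as '(3,3)_(3,3)') and parse them into lists before constructing the ConvEncoder. A small stand-alone sketch of that parsing with hypothetical specification strings:

    def parse_int_list(spec):
        """'32_32' -> [32, 32]; '' -> []"""
        return [int(c) for c in spec.split('_')] if spec else []

    def parse_pair_list(spec):
        """'(3,3)_(2,2)' -> [[3, 3], [2, 2]]; '' -> []"""
        pairs = []
        for c in (spec.split('_') if spec else []):
            h, w = c.replace('(', '').replace(')', '').split(',')
            pairs.append([int(h), int(w)])
        return pairs

    channels = parse_int_list('32_32')              # [32, 32]
    kernel_sizes = parse_pair_list('(3,3)_(3,3)')   # [[3, 3], [3, 3]]
    strides = parse_pair_list('(1,1)_(1,1)')        # [[1, 1], [1, 1]]
    poolings = parse_pair_list('(2,2)_(2,2)')       # [[2, 2], [2, 2]]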
Example #7
    def __init__(self, input_dim, enc_type, attn_type, n_heads, n_layers,
                 n_layers_sub1, n_layers_sub2, d_model, d_ff, last_proj_dim,
                 pe_type, layer_norm_eps, ffn_activation, dropout_in, dropout,
                 dropout_att, dropout_residual, n_stacks, n_splices,
                 conv_in_channel, conv_channels, conv_kernel_sizes,
                 conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm,
                 conv_bottleneck_dim, conv_param_init, task_specific_layer,
                 param_init, chunk_size_left, chunk_size_current,
                 chunk_size_right):

        super(TransformerEncoder, self).__init__()

        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 and n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 and n_layers_sub1.')

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type

        # for latency-controlled
        self.chunk_size_left = chunk_size_left
        self.chunk_size_cur = chunk_size_current
        self.chunk_size_right = chunk_size_right

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs before the self-attention layers
        if conv_channels:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model, d_ff, attn_type, n_heads,
                                        dropout, dropout_att,
                                        dropout_residual * (l + 1) / n_layers,
                                        layer_norm_eps, ffn_activation,
                                        param_init)) for l in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model, d_ff, attn_type, n_heads, dropout, dropout_att,
                    dropout_residual * n_layers_sub1 / n_layers,
                    layer_norm_eps, ffn_activation, param_init)
            self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model, d_ff, attn_type, n_heads, dropout, dropout_att,
                    dropout_residual * n_layers_sub2 / n_layers,
                    layer_norm_eps, ffn_activation, param_init)
            self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

        if last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor()

        if param_init == 'xavier_uniform':
            self.reset_parameters()
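Every example on this page wraps the embedded (or linearly projected) inputs in a PositionalEncoding module. The repository's class takes a pe_type argument; the sketch below covers only the standard additive sinusoidal case and is a hypothetical stand-in, not the repository's implementation:

    import math

    import torch
    import torch.nn as nn

    class SinusoidalPositionalEncoding(nn.Module):
        """Additive sinusoidal positional encoding (the pe_type='add' case only)."""

        def __init__(self, d_model, dropout, max_len=5000):
            super().__init__()
            self.dropout = nn.Dropout(p=dropout)
            pe = torch.zeros(max_len, d_model)
            position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
            div_term = torch.exp(torch.arange(0, d_model, 2).float()
                                 * (-math.log(10000.0) / d_model))
            pe[:, 0::2] = torch.sin(position * div_term)
            pe[:, 1::2] = torch.cos(position * div_term)
            self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

        def forward(self, xs):
            # xs: (batch, time, d_model); scale, add position information, then dropout
            xs = xs * math.sqrt(xs.size(-1)) + self.pe[:, :xs.size(1)]
            return self.dropout(xs)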
Example #8
    def __init__(self,
                 input_dim,
                 attn_type,
                 attn_n_heads,
                 n_layers,
                 d_model,
                 d_ff,
                 pe_type,
                 dropout_in=0,
                 dropout=0,
                 dropout_att=0,
                 layer_norm_eps=1e-6,
                 n_stacks=1,
                 n_splices=1,
                 conv_in_channel=1,
                 conv_channels=0,
                 conv_kernel_sizes=[],
                 conv_strides=[],
                 conv_poolings=[],
                 conv_batch_norm=False,
                 conv_residual=False,
                 conv_bottleneck_dim=0):

        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.pe_type = pe_type

        # Setting for CNNs before the self-attention layers
        if conv_channels:
            channels = [int(c) for c in conv_channels.split('_')
                        ] if len(conv_channels) > 0 else []
            kernel_sizes = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_kernel_sizes.split('_')
                            ] if len(conv_kernel_sizes) > 0 else []
            strides = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_strides.split('_')
                       ] if len(conv_strides) > 0 else []
            poolings = [[
                int(c.split(',')[0].replace('(', '')),
                int(c.split(',')[1].replace(')', ''))
            ] for c in conv_poolings.split('_')
                        ] if len(conv_poolings) > 0 else []
        else:
            channels = []
            kernel_sizes = []
            strides = []
            poolings = []

        if len(channels) > 0:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim * n_stacks,
                                    in_channel=conv_in_channel,
                                    channels=channels,
                                    kernel_sizes=kernel_sizes,
                                    strides=strides,
                                    poolings=poolings,
                                    dropout=0,
                                    batch_norm=conv_batch_norm,
                                    residual=conv_residual,
                                    bottleneck_dim=d_model)
            self._output_dim = self.conv.output_dim
        else:
            self._output_dim = input_dim * n_splices * n_stacks
            self.conv = None

            self.embed_in = LinearND(
                self._output_dim, d_model,
                dropout=0)  # NOTE: do not apply dropout here

        if pe_type:
            self.pos_emb_in = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layer_norm_in = nn.LayerNorm(d_model, eps=layer_norm_eps)

        # Self-attention layers
        self.layers = nn.ModuleList([
            TransformerEncoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                    dropout, dropout_att, layer_norm_eps)
            for _ in range(n_layers)
        ])
        self.layer_norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._output_dim = d_model
Example #9
    def __init__(self, input_dim, attn_type, n_heads, n_layers, d_model, d_ff,
                 last_proj_dim, pe_type, layer_norm_eps, ffn_activation,
                 dropout_in, dropout, dropout_att, n_stacks, n_splices,
                 conv_in_channel, conv_channels, conv_kernel_sizes,
                 conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm,
                 conv_bottleneck_dim, conv_param_init, param_init,
                 chunk_size_left, chunk_size_current, chunk_size_right):

        super(TransformerEncoder, self).__init__()

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.chunk_size_left = chunk_size_left
        self.chunk_size_current = chunk_size_current
        self.chunk_size_right = chunk_size_right

        # Setting for CNNs before the self-attention layers
        if conv_channels:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type)
        self.layers = repeat(
            TransformerEncoderBlock(d_model, d_ff, attn_type, n_heads, dropout,
                                    dropout_att, layer_norm_eps,
                                    ffn_activation, param_init), n_layers)
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._odim = d_model

        if last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim
        else:
            self.bridge = None

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor()

        if param_init == 'xavier_uniform':
            self.reset_parameters()
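Examples #1, #2, and #9 build their layer stacks with a repeat(...) helper instead of an explicit list comprehension. A plausible sketch of such a helper, assuming it simply deep-copies the prototype block into an nn.ModuleList (the helper in the source repository may differ in detail):

    import copy

    import torch.nn as nn

    def repeat(module, n_layers):
        """Return n_layers independent copies of `module` as an nn.ModuleList."""
        return nn.ModuleList([copy.deepcopy(module) for _ in range(n_layers)])

    # usage, mirroring Example #9:
    #     self.layers = repeat(TransformerEncoderBlock(...), n_layers)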