コード例 #1
0
    def __init__(self, input_dim, enc_type, n_heads, n_layers, n_layers_sub1,
                 n_layers_sub2, d_model, d_ff, ffn_bottleneck_dim,
                 ffn_activation, pe_type, layer_norm_eps, last_proj_dim,
                 dropout_in, dropout, dropout_att, dropout_layer, subsample,
                 subsample_type, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init, clamp_len,
                 lookahead, chunk_size_left, chunk_size_current,
                 chunk_size_right, streaming_type):
        """Build a (streaming-capable) Transformer encoder.

        Sets up, in order: per-layer subsampling/lookahead parsing, streaming
        (latency-controlled) chunk configuration, an optional CNN frontend,
        positional encoding/embedding, the stack of Transformer blocks, and
        optional sub-task branches plus bridge projections.

        Raises:
            ValueError: if the parsed subsample list does not match
                ``n_layers``, or if ``n_layers_sub1``/``n_layers_sub2`` are
                outside their valid ranges.
        """
        super(TransformerEncoder, self).__init__()

        # parse subsample: underscore-joined per-layer factors (e.g. '1_2_2');
        # layers beyond the parsed prefix keep the default factor of 1
        subsamples = [1] * n_layers
        for lth, s in enumerate(list(map(int,
                                         subsample.split('_')[:n_layers]))):
            subsamples[lth] = s
        # parse lookahead: underscore-joined per-layer lookahead frames (default 0)
        lookaheads = [0] * n_layers
        for lth, s in enumerate(list(map(int,
                                         lookahead.split('_')[:n_layers]))):
            lookaheads[lth] = s

        if len(subsamples) > 0 and len(subsamples) != n_layers:
            raise ValueError(
                'subsample must be the same size as n_layers. n_layers: %d, subsample: %s'
                % (n_layers, subsamples))
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            # BUGFIX: was `raise Warning(...)`. Raising the builtin Warning
            # class still aborts like any exception but mislabels a hard
            # configuration error; ValueError matches the sibling encoders.
            raise ValueError(
                'Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d'
                % (n_layers, n_layers_sub1))
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            # BUGFIX: was `raise Warning(...)` (same reasoning as above).
            raise ValueError(
                'Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d'
                % (n_layers_sub1, n_layers_sub2))

        self.enc_type = enc_type
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.scale = math.sqrt(d_model)  # input scaling before the encoder stack

        # for compatibility: chunk sizes may arrive as int or 'a_b_c' string
        chunk_size_left = str(chunk_size_left)
        chunk_size_current = str(chunk_size_current)
        chunk_size_right = str(chunk_size_right)

        # for streaming encoder
        self.unidir = 'uni' in enc_type
        self.lookaheads = lookaheads
        if sum(lookaheads) > 0:
            assert self.unidir
        # only the last underscore-separated field is used; divide by n_stacks
        # because frame stacking shortens the input sequence accordingly
        self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
        self.chunk_size_current = int(
            chunk_size_current.split('_')[-1]) // n_stacks
        self.chunk_size_right = int(
            chunk_size_right.split('_')[-1]) // n_stacks
        self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
        self.cnn_lookahead = self.unidir or enc_type == 'conv'
        self.streaming_type = streaming_type if self.lc_bidir else ''
        # -: past context
        # *: current context
        # +: future context
        # reshape) overlapped windowing. additional redundant computation is introduced.
        # During inference, caching is not applied. However, considering (N_l+N_c+N_r) is very short
        # and independent on layer depth, the overhead is negligible.
        # chunk1: |**|++
        # chunk2:  --|**|++
        # chunk3:     --|**|++
        # chunk4:        --|**|++
        # chunk5:           --|**|++
        # mask) chunkwise masking. future context is restricted within the current chunk
        # to avoid accumuration of future context depending on the layer depth.
        # chunk1: |**|
        # chunk2:  --|**|
        # chunk3:  -- --|**|
        # chunk4:     -- --|**|
        # chunk5:        -- --|**|
        if self.unidir:
            assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
        if self.streaming_type == 'mask':
            assert self.chunk_size_right == 0
            assert self.chunk_size_left == self.chunk_size_current
            # NOTE: this is important to cache CNN output at each chunk
        if self.lc_bidir:
            assert n_layers_sub1 == 0
            assert n_layers_sub2 == 0
            assert not self.unidir

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers (created lazily below when last_proj_dim is set)
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs
        if 'conv' in enc_type:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        # calculate subsampling factor (CNN frontend x inter-layer subsampling)
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        self.subsample = None
        if np.prod(subsamples) > 1:
            self._factor *= np.prod(subsamples)
            if subsample_type == 'max_pool':
                self.subsample = nn.ModuleList(
                    [MaxpoolSubsampler(factor) for factor in subsamples])
            elif subsample_type == 'concat':
                self.subsample = nn.ModuleList([
                    ConcatSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'drop':
                self.subsample = nn.ModuleList(
                    [DropSubsampler(factor) for factor in subsamples])
            elif subsample_type == '1dconv':
                self.subsample = nn.ModuleList([
                    Conv1dSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'add':
                self.subsample = nn.ModuleList(
                    [AddSubsampler(factor) for factor in subsamples])

        # chunk sizes must align with the total subsampling factor
        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_current > 0:
            assert self.chunk_size_current % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        # positional encoding: relative variants use XL-style embeddings,
        # otherwise absolute positional encoding is applied to the input
        self.pos_enc, self.pos_emb = None, None
        self.u_bias, self.v_bias = None, None
        if pe_type in ['relative', 'relative_xl']:
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
            if pe_type == 'relative_xl':
                self.u_bias = nn.Parameter(
                    torch.Tensor(n_heads, d_model // n_heads))
                self.v_bias = nn.Parameter(
                    torch.Tensor(n_heads, d_model // n_heads))
                # NOTE: u_bias and v_bias are global parameters shared in the whole model
        else:
            self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type,
                                              param_init)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model, d_ff, n_heads, dropout,
                                        dropout_att, dropout_layer,
                                        layer_norm_eps, ffn_activation,
                                        param_init, pe_type, clamp_len,
                                        ffn_bottleneck_dim))
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, clamp_len, ffn_bottleneck_dim)
            odim_sub1 = d_model
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
                odim_sub1 = last_proj_dim
            if n_layers_sub1 == n_layers:
                # the shared norm_out is used when the sub-branch taps the top layer
                self.norm_out_sub1 = None
            else:
                self.norm_out_sub1 = nn.LayerNorm(odim_sub1,
                                                  eps=layer_norm_eps)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, clamp_len, ffn_bottleneck_dim)
            odim_sub2 = d_model
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
                odim_sub2 = last_proj_dim
            if n_layers_sub2 == n_layers:
                self.norm_out_sub2 = None
            else:
                self.norm_out_sub2 = nn.LayerNorm(odim_sub2,
                                                  eps=layer_norm_eps)

        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)

        # for streaming inference
        self.reset_cache()
コード例 #2
0
    def __init__(self, input_dim, enc_type, n_heads, n_layers, n_layers_sub1,
                 n_layers_sub2, d_model, d_ff, ffn_bottleneck_dim,
                 last_proj_dim, pe_type, layer_norm_eps, ffn_activation,
                 dropout_in, dropout, dropout_att, dropout_layer, subsample,
                 subsample_type, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init,
                 chunk_size_left, chunk_size_current, chunk_size_right,
                 latency_control_type):
        """Build a Transformer encoder with optional latency control.

        Constructs, in order: per-layer subsampling parsing and validation,
        streaming/chunking configuration, an optional CNN frontend, positional
        encoding/embedding, the Transformer block stack, and optional sub-task
        branches plus bridge projections.

        Raises:
            ValueError: if the subsample list length, ``n_layers_sub1`` or
                ``n_layers_sub2`` are inconsistent with ``n_layers``.
        """
        super(TransformerEncoder, self).__init__()

        # parse subsample: underscore-joined per-layer factors (e.g. '1_2_2');
        # layers beyond the parsed prefix keep the default factor of 1
        subsamples = [1] * n_layers
        for lth, s in enumerate(list(map(int,
                                         subsample.split('_')[:n_layers]))):
            subsamples[lth] = s

        if len(subsamples) > 0 and len(subsamples) != n_layers:
            raise ValueError(
                'subsample must be the same size as n_layers. n_layers: %d, subsample: %s'
                % (n_layers, subsamples))
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')
        assert enc_type in ['transformer', 'conv_transformer']

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        # input scaling factor applied before the encoder stack
        self.scale = math.sqrt(d_model)

        # for streaming encoder: chunk sizes here are plain ints (not strings
        # as in other encoder variants)
        self.chunk_size_left = chunk_size_left
        self.chunk_size_current = chunk_size_current
        self.chunk_size_right = chunk_size_right
        self.latency_controlled = chunk_size_left > 0 or chunk_size_current > 0 or chunk_size_right > 0
        self.lc_type = latency_control_type
        # reshape) not lookahead frames in CNN layers, but requires some additional computations
        # mask) there are some lookahead frames in CNN layers, no additional computations

        # TransformerXL like streaming encoder
        # NOTE(review): the assert above restricts enc_type to
        # 'transformer'/'conv_transformer', so this flag can never be True
        # here — likely leftover from a variant supporting 'transformer_xl';
        # confirm before relying on the memory path below.
        self.memory_transformer = ('transformer_xl' in enc_type)
        self.mem_len = chunk_size_left
        if self.memory_transformer:
            assert pe_type == 'relative'
            assert chunk_size_left > 0
            assert chunk_size_current > 0

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers (created below when last_proj_dim is set)
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs
        if 'conv' in enc_type:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        # calculate subsampling factor (CNN frontend x inter-layer subsampling)
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        self.subsample = None
        if np.prod(subsamples) > 1:
            self._factor *= np.prod(subsamples)
            if subsample_type == 'max_pool':
                self.subsample = nn.ModuleList(
                    [MaxpoolSubsampler(factor) for factor in subsamples])
            elif subsample_type == 'concat':
                self.subsample = nn.ModuleList([
                    ConcatSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'drop':
                self.subsample = nn.ModuleList(
                    [DropSubsampler(factor) for factor in subsamples])
            elif subsample_type == '1dconv':
                self.subsample = nn.ModuleList([
                    Conv1dSubsampler(factor, self._odim)
                    for factor in subsamples
                ])

        # chunk sizes must align with the total subsampling factor
        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_current > 0:
            assert self.chunk_size_current % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        # positional encoding: XL-style relative embedding when requested,
        # otherwise absolute positional encoding
        # NOTE(review): self.pos_enc is only assigned in the final else branch,
        # so it does not exist for relative pe_type — callers must check pe_type
        self.pos_emb = None
        self.u = None
        self.v = None
        if self.memory_transformer:
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
            self.u = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            self.v = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            # NOTE: u and v are global parameters
        elif pe_type == 'relative':
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        else:
            self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type,
                                              param_init)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model,
                                        d_ff,
                                        n_heads,
                                        dropout,
                                        dropout_att,
                                        dropout_layer,
                                        layer_norm_eps,
                                        ffn_activation,
                                        param_init,
                                        relative_attention=self.pos_emb
                                        is not None,
                                        ffn_bottleneck_dim=ffn_bottleneck_dim))
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model,
                    d_ff,
                    n_heads,
                    dropout,
                    dropout_att,
                    dropout_layer,
                    layer_norm_eps,
                    ffn_activation,
                    param_init,
                    ffn_bottleneck_dim=ffn_bottleneck_dim)
            self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model,
                    d_ff,
                    n_heads,
                    dropout,
                    dropout_att,
                    dropout_layer,
                    layer_norm_eps,
                    ffn_activation,
                    param_init,
                    ffn_bottleneck_dim=ffn_bottleneck_dim)
            self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)
コード例 #3
0
ファイル: rnn.py プロジェクト: pradipcyb/neural_sp
    def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
                 n_layers, n_layers_sub1, n_layers_sub2, dropout_in, dropout,
                 subsample, subsample_type, n_stacks, n_splices,
                 conv_in_channel, conv_channels, conv_kernel_sizes,
                 conv_strides, conv_poolings, conv_batch_norm, conv_layer_norm,
                 conv_bottleneck_dim, bidir_sum_fwd_bwd, task_specific_layer,
                 param_init, chunk_size_left, chunk_size_right):
        """Build an RNN (LSTM/GRU) encoder with optional latency control.

        Constructs an optional CNN frontend followed by a per-layer loop that
        builds RNN layers (separate fwd/bwd stacks when latency-controlled),
        projection layers, subsampling modules, and optional task-specific
        sub-branch layers, then the bridge projection.

        Raises:
            ValueError: if the subsample list length, ``n_layers_sub1``,
                ``n_layers_sub2`` or ``enc_type`` are inconsistent.
        """
        super(RNNEncoder, self).__init__()

        # parse subsample: underscore-joined per-layer factors (e.g. '1_2_2');
        # layers beyond the parsed prefix keep the default factor of 1
        subsamples = [1] * n_layers
        for lth, s in enumerate(list(map(int,
                                         subsample.split('_')[:n_layers]))):
            subsamples[lth] = s

        if len(subsamples) > 0 and len(subsamples) != n_layers:
            raise ValueError(
                'subsample must be the same size as n_layers. n_layers: %d, subsample: %s'
                % (n_layers, subsamples))
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError(
                'Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d'
                % (n_layers, n_layers_sub1))
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError(
                'Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d'
                % (n_layers_sub1, n_layers_sub2))

        self.enc_type = enc_type
        self.bidirectional = True if ('blstm' in enc_type
                                      or 'bgru' in enc_type) else False
        self.n_units = n_units
        self.n_dirs = 2 if self.bidirectional else 1
        self.n_layers = n_layers
        self.bidir_sum = bidir_sum_fwd_bwd

        # for latency-controlled
        # NOTE(review): chunk_size_left/right are expected to be underscore-
        # joined strings here (only the first field is used); divided by
        # n_stacks because frame stacking shortens the input sequence
        self.chunk_size_left = int(chunk_size_left.split('_')[0]) // n_stacks
        self.chunk_size_right = int(chunk_size_right.split('_')[0]) // n_stacks
        self.lc_bidir = self.chunk_size_left > 0 or self.chunk_size_right > 0
        if self.lc_bidir:
            # latency control requires a bidirectional (and non-pure-conv) type
            assert enc_type not in ['lstm', 'gru', 'conv_lstm', 'conv_gru']
            assert n_layers_sub2 == 0

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers (created below when last_proj_dim is set)
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # Dropout for input-hidden connection
        self.dropout_in = nn.Dropout(p=dropout_in)

        if 'conv' in enc_type:
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    residual=False,
                                    bottleneck_dim=conv_bottleneck_dim,
                                    param_init=param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks

        if enc_type != 'conv':
            self.rnn = nn.ModuleList()
            if self.lc_bidir:
                # separate backward stack for latency-controlled processing
                self.rnn_bwd = nn.ModuleList()
            self.dropout = nn.Dropout(p=dropout)
            self.proj = nn.ModuleList() if n_projs > 0 else None
            self.subsample = nn.ModuleList(
            ) if np.prod(subsamples) > 1 else None
            self.padding = Padding(bidir_sum_fwd_bwd=bidir_sum_fwd_bwd
                                   if not self.lc_bidir else False)

            for lth in range(n_layers):
                if 'lstm' in enc_type:
                    rnn_i = nn.LSTM
                elif 'gru' in enc_type:
                    rnn_i = nn.GRU
                else:
                    raise ValueError(
                        'enc_type must be "(conv_)(b)lstm" or "(conv_)(b)gru".'
                    )

                if self.lc_bidir:
                    # two unidirectional RNNs (fwd + bwd) instead of one
                    # bidirectional RNN, so chunks can be processed separately
                    self.rnn += [
                        rnn_i(self._odim, n_units, 1, batch_first=True)
                    ]
                    self.rnn_bwd += [
                        rnn_i(self._odim, n_units, 1, batch_first=True)
                    ]
                else:
                    self.rnn += [
                        rnn_i(self._odim,
                              n_units,
                              1,
                              batch_first=True,
                              bidirectional=self.bidirectional)
                    ]
                # fwd/bwd outputs are summed (same dim) or concatenated (x2)
                self._odim = n_units if bidir_sum_fwd_bwd else n_units * self.n_dirs

                # Projection layer (not after the last layer)
                if self.proj is not None:
                    if lth != n_layers - 1:
                        self.proj += [nn.Linear(self._odim, n_projs)]
                        self._odim = n_projs

                # subsample
                if np.prod(subsamples) > 1:
                    if subsample_type == 'max_pool':
                        self.subsample += [MaxpoolSubsampler(subsamples[lth])]
                    elif subsample_type == 'concat':
                        self.subsample += [
                            ConcatSubsampler(subsamples[lth], self._odim)
                        ]
                    elif subsample_type == 'drop':
                        self.subsample += [DropSubsampler(subsamples[lth])]
                    elif subsample_type == '1dconv':
                        self.subsample += [
                            Conv1dSubsampler(subsamples[lth], self._odim)
                        ]
                    elif subsample_type == 'add':
                        self.subsample += [AddSubsampler(subsamples[lth])]

                # Task specific layer
                if lth == n_layers_sub1 - 1 and task_specific_layer:
                    self.rnn_sub1 = rnn_i(self._odim,
                                          n_units,
                                          1,
                                          batch_first=True,
                                          bidirectional=self.bidirectional)
                    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                        self.bridge_sub1 = nn.Linear(n_units, last_proj_dim)
                if lth == n_layers_sub2 - 1 and task_specific_layer:
                    assert not self.lc_bidir
                    self.rnn_sub2 = rnn_i(self._odim,
                                          n_units,
                                          1,
                                          batch_first=True,
                                          bidirectional=self.bidirectional)
                    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                        self.bridge_sub2 = nn.Linear(n_units, last_proj_dim)

            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge = nn.Linear(self._odim, last_proj_dim)
                self._odim = last_proj_dim

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        elif np.prod(subsamples) > 1:
            self._factor *= np.prod(subsamples)
        # NOTE: subsampling factor for frame stacking should not be included here
        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        self.reset_parameters(param_init)

        # for streaming inference
        self.reset_cache()
コード例 #4
0
ファイル: rnn.py プロジェクト: ishine/neural_sp
    def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
                 n_layers, n_layers_sub1, n_layers_sub2, dropout_in, dropout,
                 subsample, subsample_type, n_stacks, n_splices, frontend_conv,
                 bidir_sum_fwd_bwd, task_specific_layer, param_init,
                 chunk_size_current, chunk_size_right, cnn_lookahead,
                 rsp_prob):
        """Build an RNN (LSTM) encoder with optional latency control.

        Takes a pre-built CNN frontend (``frontend_conv``), then builds the
        per-layer LSTM stack (separate fwd/bwd stacks when latency-controlled),
        optional task-specific linear layers, projection and subsampling
        modules, and finally validates chunk sizes against the total
        subsampling factor.

        Raises:
            ValueError: if the parsed subsample list does not match
                ``n_layers``, or if ``n_layers_sub1``/``n_layers_sub2`` are
                outside their valid ranges.
        """
        super(RNNEncoder, self).__init__()

        # parse subsample: underscore-joined per-layer factors (e.g. '1_2_2');
        # layers beyond the parsed prefix keep the default factor of 1
        subsamples = [1] * n_layers
        for lth, s in enumerate(list(map(int,
                                         subsample.split('_')[:n_layers]))):
            subsamples[lth] = s

        if len(subsamples) > 0 and len(subsamples) != n_layers:
            raise ValueError(
                'subsample must be the same size as n_layers. n_layers: %d, subsample: %s'
                % (n_layers, subsamples))
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            # BUGFIX: was `raise Warning(...)`. Raising the builtin Warning
            # class aborts like any exception but mislabels a hard
            # configuration error; ValueError matches the sibling encoders.
            raise ValueError(
                'Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d'
                % (n_layers, n_layers_sub1))
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            # BUGFIX: was `raise Warning(...)` (same reasoning as above).
            raise ValueError(
                'Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d'
                % (n_layers_sub1, n_layers_sub2))

        self.enc_type = enc_type
        self.bidirectional = True if 'blstm' in enc_type else False
        self.n_units = n_units
        self.n_dirs = 2 if self.bidirectional else 1
        self.n_layers = n_layers
        self.bidir_sum = bidir_sum_fwd_bwd

        # for compatiblity: chunk sizes may arrive as int or 'a_b' string
        chunk_size_current = str(chunk_size_current)
        chunk_size_right = str(chunk_size_right)

        # for latency-controlled: N_c = current chunk, N_r = right (future)
        # context, in input frames divided by the frame-stacking factor
        self.N_c = int(chunk_size_current.split('_')[0]) // n_stacks
        self.N_r = int(chunk_size_right.split('_')[0]) // n_stacks
        self.lc_bidir = (self.N_c > 0 or self.N_r > 0) and self.bidirectional
        if self.lc_bidir:
            assert enc_type not in ['lstm', 'conv_lstm']
            assert n_layers_sub2 == 0

        # for streaming
        self.rsp_prob = rsp_prob

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers (created below when last_proj_dim is set)
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # Dropout for input-hidden connection
        self.dropout_in = nn.Dropout(p=dropout_in)

        self.conv = frontend_conv
        if self.conv is not None:
            self._odim = self.conv.output_dim
        else:
            self._odim = input_dim * n_splices * n_stacks
        self.cnn_lookahead = cnn_lookahead
        if not cnn_lookahead:
            # disabling CNN lookahead only makes sense with chunkwise
            # latency-controlled processing
            assert self.N_c > 0
            assert self.lc_bidir

        if enc_type != 'conv':
            self.rnn = nn.ModuleList()
            if self.lc_bidir:
                # separate backward stack for latency-controlled processing
                self.rnn_bwd = nn.ModuleList()
            self.dropout = nn.Dropout(p=dropout)
            self.proj = nn.ModuleList() if n_projs > 0 else None
            self.subsample = nn.ModuleList(
            ) if np.prod(subsamples) > 1 else None
            self.padding = Padding(bidir_sum_fwd_bwd=bidir_sum_fwd_bwd
                                   if not self.lc_bidir else False)

            for lth in range(n_layers):
                if self.lc_bidir:
                    # two unidirectional LSTMs (fwd + bwd) instead of one
                    # bidirectional LSTM, so chunks can be processed separately
                    self.rnn += [
                        nn.LSTM(self._odim, n_units, 1, batch_first=True)
                    ]
                    self.rnn_bwd += [
                        nn.LSTM(self._odim, n_units, 1, batch_first=True)
                    ]
                else:
                    self.rnn += [
                        nn.LSTM(self._odim,
                                n_units,
                                1,
                                batch_first=True,
                                bidirectional=self.bidirectional)
                    ]
                # fwd/bwd outputs are summed (same dim) or concatenated (x2)
                self._odim = n_units if bidir_sum_fwd_bwd else n_units * self.n_dirs

                # Task specific layer
                if lth == n_layers_sub1 - 1 and task_specific_layer:
                    self.layer_sub1 = nn.Linear(self._odim, n_units)
                    self._odim_sub1 = n_units
                    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                        self.bridge_sub1 = nn.Linear(n_units, last_proj_dim)
                        self._odim_sub1 = last_proj_dim
                if lth == n_layers_sub2 - 1 and task_specific_layer:
                    assert not self.lc_bidir
                    self.layer_sub2 = nn.Linear(self._odim, n_units)
                    self._odim_sub2 = n_units
                    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                        self.bridge_sub2 = nn.Linear(n_units, last_proj_dim)
                        self._odim_sub2 = last_proj_dim

                # Projection layer (not after the last layer)
                if self.proj is not None:
                    if lth != n_layers - 1:
                        self.proj += [nn.Linear(self._odim, n_projs)]
                        self._odim = n_projs

                # subsample
                if np.prod(subsamples) > 1:
                    if subsample_type == 'max_pool':
                        self.subsample += [MaxPoolSubsampler(subsamples[lth])]
                    elif subsample_type == 'mean_pool':
                        self.subsample += [MeanPoolSubsampler(subsamples[lth])]
                    elif subsample_type == 'concat':
                        self.subsample += [
                            ConcatSubsampler(subsamples[lth], self._odim)
                        ]
                    elif subsample_type == 'drop':
                        self.subsample += [DropSubsampler(subsamples[lth])]
                    elif subsample_type == 'conv1d':
                        self.subsample += [
                            Conv1dSubsampler(subsamples[lth], self._odim)
                        ]
                    elif subsample_type == 'add':
                        self.subsample += [AddSubsampler(subsamples[lth])]

            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge = nn.Linear(self._odim, last_proj_dim)
                self._odim = last_proj_dim

        # calculate subsampling factor (CNN frontend x inter-layer subsampling)
        self.conv_factor = self.conv.subsampling_factor if self.conv is not None else 1
        self._factor = self.conv_factor
        self._factor_sub1 = self.conv_factor
        self._factor_sub2 = self.conv_factor
        if n_layers_sub1 > 1:
            self._factor_sub1 *= np.prod(subsamples[:n_layers_sub1 - 1])
        if n_layers_sub2 > 1:
            # BUGFIX: was `self._factor_sub1 *= ...` (copy-paste from the
            # sub1 branch above); the sub2 factor must accumulate into
            # self._factor_sub2, otherwise it never reflects subsampling.
            self._factor_sub2 *= np.prod(subsamples[:n_layers_sub2 - 1])
        self._factor *= np.prod(subsamples)
        # NOTE: subsampling factor for frame stacking should not be included here
        if self.N_c > 0:
            assert self.N_c % self._factor == 0
        if self.N_r > 0:
            assert self.N_r % self._factor == 0

        self.reset_parameters(param_init)

        # for streaming inference
        self.reset_cache()