Exemplo n.º 1
0
    def __init__(self, special_symbols,
                 enc_n_units, attn_type, n_heads, n_layers,
                 d_model, d_ff, ffn_bottleneck_dim,
                 pe_type, layer_norm_eps, ffn_activation,
                 vocab, tie_embedding,
                 dropout, dropout_emb, dropout_att, dropout_layer, dropout_head,
                 lsm_prob, ctc_weight, ctc_lsm_prob, ctc_fc_list, backward,
                 global_weight, mtl_per_batch, param_init,
                 memory_transformer, mem_len,
                 mocha_chunk_size, mocha_n_heads_mono, mocha_n_heads_chunk,
                 mocha_init_r, mocha_eps, mocha_std,
                 mocha_no_denominator, mocha_1dconv,
                 mocha_quantity_loss_weight, mocha_head_divergence_loss_weight,
                 latency_metric, latency_loss_weight,
                 mocha_first_layer, share_chunkwise_attention,
                 external_lm, lm_fusion):
        """Store the decoder configuration and build its sub-modules.

        The attention weight is ``global_weight - ctc_weight``.  A ``CTC``
        module is created only when ``ctc_weight > 0``.  The token embedding,
        positional modules, the stack of ``n_layers``
        ``TransformerDecoderBlock`` layers, the output projection and the
        final ``reset_parameters(param_init)`` call are only executed when
        the attention weight is positive.

        Args:
            special_symbols (dict): must provide the keys ``'eos'``,
                ``'unk'``, ``'pad'`` and ``'blank'``.
            memory_transformer (bool): when true, ``pe_type`` must be
                ``'none'`` and an ``XLPositionalEmbedding`` is added; the
                global ``u``/``v`` parameters are created if ``mem_len > 0``.
            mocha_first_layer (int): must not exceed ``n_layers``; layers
                with index below ``mocha_first_layer - 1`` are built with
                ``src_tgt_attention=False``.
            latency_metric (str): the value ``'ctc_sync'`` enables
                ``ctc_trigger`` and then requires ``0 < ctc_weight < 1``.
            external_lm: stored as ``self.lm``; when not ``None`` a linear
                projection from ``external_lm.output_dim`` to ``d_model`` is
                added.

        The remaining arguments are stored on the instance or forwarded
        unchanged to the sub-module constructors.

        Raises:
            AssertionError: if one of the configuration checks above fails.
            KeyError: if ``special_symbols`` lacks a required key.
        """

        super(TransformerDecoder, self).__init__()

        # vocabulary and special symbol ids
        self.vocab = vocab
        self.eos = special_symbols['eos']
        self.unk = special_symbols['unk']
        self.pad = special_symbols['pad']
        self.blank = special_symbols['blank']

        # model geometry
        self.enc_n_units = enc_n_units
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type

        # loss configuration
        self.lsm_prob = lsm_prob
        self.ctc_weight = ctc_weight
        self.att_weight = global_weight - ctc_weight
        self.mtl_per_batch = mtl_per_batch
        self.bwd = backward

        # state initialised empty and filled in elsewhere
        self.prev_spk = ''
        self.lmstate_final = None

        # for TransformerXL decoder
        self.memory_transformer = memory_transformer
        self.mem_len = mem_len
        if memory_transformer:
            assert pe_type == 'none'

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # for MMA
        self.attn_type = attn_type
        self.mocha_first_layer = mocha_first_layer
        self.quantity_loss_weight = mocha_quantity_loss_weight
        self._quantity_loss_weight = 0  # for curriculum
        self.headdiv_loss_weight = mocha_head_divergence_loss_weight
        self.latency_loss_weight = latency_loss_weight
        self.latency_metric = latency_metric
        self.ctc_trigger = self.latency_metric in ('ctc_sync',)
        if self.ctc_trigger:
            assert 0 < self.ctc_weight < 1

        # auxiliary CTC head
        if ctc_weight > 0:
            ctc_kwargs = dict(eos=self.eos,
                              blank=self.blank,
                              enc_n_units=enc_n_units,
                              vocab=self.vocab,
                              dropout=dropout,
                              lsm_prob=ctc_lsm_prob,
                              fc_list=ctc_fc_list,
                              param_init=0.1,
                              backward=backward)
            self.ctc = CTC(**ctc_kwargs)

        # everything below belongs to the attention branch only
        if not (self.att_weight > 0):
            return

        # token embedding
        self.embed = nn.Embedding(self.vocab, d_model, padding_idx=self.pad)
        self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type, param_init)

        # positional embedding
        self.u = None
        self.v = None
        if memory_transformer:
            self.scale = math.sqrt(d_model)  # for token embedding
            self.dropout_emb = nn.Dropout(p=dropout_emb)  # for token embedding
            self.pos_emb = XLPositionalEmbedding(d_model, dropout_emb)
            if self.mem_len > 0:
                head_dim = d_model // n_heads
                # NOTE: u and v are global parameters
                self.u = nn.Parameter(torch.Tensor(n_heads, head_dim))
                self.v = nn.Parameter(torch.Tensor(n_heads, head_dim))

        # self-attention
        assert mocha_first_layer <= n_layers
        # keyword arguments that are identical for every decoder block
        shared_kwargs = dict(
            memory_transformer=memory_transformer,
            mocha_chunk_size=mocha_chunk_size,
            mocha_n_heads_mono=mocha_n_heads_mono,
            mocha_n_heads_chunk=mocha_n_heads_chunk,
            mocha_init_r=mocha_init_r,
            mocha_eps=mocha_eps,
            mocha_std=mocha_std,
            mocha_no_denominator=mocha_no_denominator,
            mocha_1dconv=mocha_1dconv,
            dropout_head=dropout_head,
            lm_fusion=lm_fusion,
            ffn_bottleneck_dim=ffn_bottleneck_dim,
            share_chunkwise_attention=share_chunkwise_attention)
        decoder_blocks = []
        for lth in range(n_layers):
            # only layers from index (mocha_first_layer - 1) onwards get
            # source-target attention
            with_src_tgt = not (lth < mocha_first_layer - 1)
            block = TransformerDecoderBlock(
                d_model, d_ff, attn_type, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                src_tgt_attention=with_src_tgt,
                **shared_kwargs)
            decoder_blocks.append(copy.deepcopy(block))
        self.layers = nn.ModuleList(decoder_blocks)

        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.output = nn.Linear(d_model, self.vocab)
        if tie_embedding:
            # share the embedding matrix with the output projection
            self.output.weight = self.embed.weight

        self.lm = external_lm
        if external_lm is not None:
            self.lm_output_proj = nn.Linear(external_lm.output_dim, d_model)

        self.reset_parameters(param_init)
Exemplo n.º 2
0
    def __init__(self, input_dim, enc_type, n_heads, n_layers, n_layers_sub1,
                 n_layers_sub2, d_model, d_ff, ffn_bottleneck_dim,
                 ffn_activation, pe_type, layer_norm_eps, last_proj_dim,
                 dropout_in, dropout, dropout_att, dropout_layer, subsample,
                 subsample_type, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init, clamp_len,
                 lookahead, chunk_size_left, chunk_size_current,
                 chunk_size_right, streaming_type):
        """Store the encoder configuration and build its sub-modules.

        Args:
            subsample (str): '_'-separated integer factors, one per layer.
                Layers without an entry default to 1; entries beyond
                ``n_layers`` are ignored.
            lookahead (str): same format as ``subsample`` with default 0.
                A non-zero total requires ``'uni'`` to occur in ``enc_type``.
            subsample_type (str): one of ``'max_pool'``, ``'concat'``,
                ``'drop'``, ``'1dconv'`` or ``'add'``.  Only consulted when
                the product of the subsampling factors exceeds 1.
            chunk_size_left / chunk_size_current / chunk_size_right: int or
                '_'-separated string; only the last field is used and it is
                integer-divided by ``n_stacks``.  Positive values must be
                multiples of the total subsampling factor.
            pe_type (str): ``'relative'`` / ``'relative_xl'`` select
                ``XLPositionalEmbedding`` (the latter also creates the global
                ``u_bias`` / ``v_bias`` parameters); any other value is
                forwarded to ``PositionalEncoding``.
            last_proj_dim (int): when positive and different from
                ``self.output_dim``, linear bridge layers are added.
            conv_bottleneck_dim: accepted but not used in this constructor
                (``d_model`` is passed to ``ConvEncoder`` instead).

        The remaining arguments are stored on the instance or forwarded
        unchanged to the sub-module constructors.

        Raises:
            Warning: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of range.
            NotImplementedError: if subsampling is requested with an unknown
                ``subsample_type``.
            AssertionError: if a streaming / CNN configuration check fails.
        """

        super(TransformerEncoder, self).__init__()

        # parse subsample
        # NOTE: the spec is padded with 1 / truncated to exactly n_layers
        # entries, so no separate length validation is needed afterwards.
        subsamples = [1] * n_layers
        for lth, s in enumerate(map(int, subsample.split('_')[:n_layers])):
            subsamples[lth] = s
        # parse lookahead (padded with 0 / truncated the same way)
        lookaheads = [0] * n_layers
        for lth, s in enumerate(map(int, lookahead.split('_')[:n_layers])):
            lookaheads[lth] = s

        # NOTE(review): these raise ``Warning`` (an Exception subclass), not
        # ValueError; kept as-is so callers catching it keep working.
        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise Warning(
                'Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d'
                % (n_layers, n_layers_sub1))
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise Warning(
                'Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d'
                % (n_layers_sub1, n_layers_sub2))

        self.enc_type = enc_type
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.scale = math.sqrt(d_model)

        # for compatibility (chunk sizes may be given as int or as string)
        chunk_size_left = str(chunk_size_left)
        chunk_size_current = str(chunk_size_current)
        chunk_size_right = str(chunk_size_right)

        # for streaming encoder
        self.unidir = 'uni' in enc_type
        self.lookaheads = lookaheads
        if sum(lookaheads) > 0:
            assert self.unidir
        # only the last '_'-separated field of each chunk size is used
        self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
        self.chunk_size_current = int(
            chunk_size_current.split('_')[-1]) // n_stacks
        self.chunk_size_right = int(
            chunk_size_right.split('_')[-1]) // n_stacks
        self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
        self.cnn_lookahead = self.unidir or enc_type == 'conv'
        self.streaming_type = streaming_type if self.lc_bidir else ''
        # -: past context
        # *: current context
        # +: future context
        # reshape) overlapped windowing. additional redundant computation is introduced.
        # During inference, caching is not applied. However, considering (N_l+N_c+N_r) is very short
        # and independent of layer depth, the overhead is negligible.
        # chunk1: |**|++
        # chunk2:  --|**|++
        # chunk3:     --|**|++
        # chunk4:        --|**|++
        # chunk5:           --|**|++
        # mask) chunkwise masking. future context is restricted within the current chunk
        # to avoid accumulation of future context depending on the layer depth.
        # chunk1: |**|
        # chunk2:  --|**|
        # chunk3:  -- --|**|
        # chunk4:     -- --|**|
        # chunk5:        -- --|**|
        if self.unidir:
            assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
        if self.streaming_type == 'mask':
            assert self.chunk_size_right == 0
            assert self.chunk_size_left == self.chunk_size_current
            # NOTE: this is important to cache CNN output at each chunk
        if self.lc_bidir:
            assert n_layers_sub1 == 0
            assert n_layers_sub2 == 0
            assert not self.unidir

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs
        if 'conv' in enc_type:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        self.subsample = None
        total_subsample = np.prod(subsamples)
        if total_subsample > 1:
            self._factor *= total_subsample
            if subsample_type == 'max_pool':
                self.subsample = nn.ModuleList(
                    [MaxpoolSubsampler(factor) for factor in subsamples])
            elif subsample_type == 'concat':
                self.subsample = nn.ModuleList([
                    ConcatSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'drop':
                self.subsample = nn.ModuleList(
                    [DropSubsampler(factor) for factor in subsamples])
            elif subsample_type == '1dconv':
                self.subsample = nn.ModuleList([
                    Conv1dSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'add':
                self.subsample = nn.ModuleList(
                    [AddSubsampler(factor) for factor in subsamples])
            else:
                # Without this, _factor would include a subsampling rate
                # that no layer actually applies.
                raise NotImplementedError(subsample_type)

        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_current > 0:
            assert self.chunk_size_current % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        self.pos_enc, self.pos_emb = None, None
        self.u_bias, self.v_bias = None, None
        if pe_type in ['relative', 'relative_xl']:
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
            if pe_type == 'relative_xl':
                self.u_bias = nn.Parameter(
                    torch.Tensor(n_heads, d_model // n_heads))
                self.v_bias = nn.Parameter(
                    torch.Tensor(n_heads, d_model // n_heads))
                # NOTE: u_bias and v_bias are global parameters shared in the whole model
        else:
            self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type,
                                              param_init)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model, d_ff, n_heads, dropout,
                                        dropout_att, dropout_layer,
                                        layer_norm_eps, ffn_activation,
                                        param_init, pe_type, clamp_len,
                                        ffn_bottleneck_dim))
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, clamp_len, ffn_bottleneck_dim)
            odim_sub1 = d_model
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
                odim_sub1 = last_proj_dim
            if n_layers_sub1 == n_layers:
                # no separate norm when the sub-task taps the last layer
                self.norm_out_sub1 = None
            else:
                self.norm_out_sub1 = nn.LayerNorm(odim_sub1,
                                                  eps=layer_norm_eps)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model, d_ff, n_heads, dropout, dropout_att,
                    dropout_layer, layer_norm_eps, ffn_activation, param_init,
                    pe_type, clamp_len, ffn_bottleneck_dim)
            odim_sub2 = d_model
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
                odim_sub2 = last_proj_dim
            if n_layers_sub2 == n_layers:
                self.norm_out_sub2 = None
            else:
                self.norm_out_sub2 = nn.LayerNorm(odim_sub2,
                                                  eps=layer_norm_eps)

        # the main bridge must come last: it changes self._odim
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)

        # for streaming inference
        self.reset_cache()
Exemplo n.º 3
0
    def __init__(self, input_dim, enc_type, n_heads, kernel_size, n_layers,
                 n_layers_sub1, n_layers_sub2, d_model, d_ff,
                 ffn_bottleneck_dim, last_proj_dim, pe_type, layer_norm_eps,
                 ffn_activation, dropout_in, dropout, dropout_att,
                 dropout_layer, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init,
                 chunk_size_left, chunk_size_current, chunk_size_right):
        """Store the Conformer encoder configuration and build its sub-modules.

        A ``ConvEncoder`` front-end is used when ``'conv'`` occurs in
        ``enc_type`` (requires ``n_stacks == n_splices == 1``); otherwise a
        linear input embedding is created.  ``n_layers``
        ``ConformerEncoderBlock`` layers follow, plus optional sub-task
        layers / norms / bridges controlled by ``n_layers_sub1``,
        ``n_layers_sub2``, ``task_specific_layer`` and ``last_proj_dim``.

        Only ``pe_type == 'relative'`` is accepted.  Positive chunk sizes
        must be multiples of the front-end subsampling factor.
        ``dropout_in`` and ``conv_bottleneck_dim`` are accepted but not used
        in this constructor.

        Raises:
            ValueError: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of
                range.
            AssertionError: if a configuration check fails.
        """

        super(ConformerEncoder, self).__init__()

        # validate the hierarchical (sub-task) layer indices
        sub1_out_of_range = n_layers_sub1 < 0 or (
            n_layers_sub1 > 1 and n_layers < n_layers_sub1)
        if sub1_out_of_range:
            raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
        sub2_out_of_range = n_layers_sub2 < 0 or (
            n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2)
        if sub2_out_of_range:
            raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # basic geometry
        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.scale = math.sqrt(d_model)

        # for streaming encoder
        self.chunk_size_left = chunk_size_left
        self.chunk_size_current = chunk_size_current
        self.chunk_size_right = chunk_size_right
        self.latency_controlled = (chunk_size_left > 0
                                   or chunk_size_current > 0
                                   or chunk_size_right > 0)

        # input front-end: plain linear embedding or CNN
        if 'conv' not in enc_type:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)
        else:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            conv_kwargs = dict(in_channel=conv_in_channel,
                               channels=conv_channels,
                               kernel_sizes=conv_kernel_sizes,
                               strides=conv_strides,
                               poolings=conv_poolings,
                               dropout=0.,
                               batch_norm=conv_batch_norm,
                               layer_norm=conv_layer_norm,
                               layer_norm_eps=layer_norm_eps,
                               residual=False,
                               bottleneck_dim=d_model,
                               param_init=conv_param_init)
            self.conv = ConvEncoder(input_dim, **conv_kwargs)
            self._odim = self.conv.output_dim

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor

        # every positive chunk size has to be a multiple of the factor
        for chunk_size in (self.chunk_size_left,
                           self.chunk_size_current,
                           self.chunk_size_right):
            if chunk_size > 0:
                assert chunk_size % self._factor == 0

        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        assert pe_type == 'relative'
        # TODO(hirofumi0810): try other positional encodings

        def _new_block():
            # All Conformer blocks in this encoder share one configuration.
            return ConformerEncoderBlock(
                d_model, d_ff, n_heads, kernel_size,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)

        def _needs_bridge():
            # A projection is added only if it would change the output size.
            return last_proj_dim > 0 and last_proj_dim != self.output_dim

        stacked_blocks = []
        for _ in range(n_layers):
            stacked_blocks.append(copy.deepcopy(_new_block()))
        self.layers = nn.ModuleList(stacked_blocks)
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._odim = d_model

        # first sub-task branch
        if n_layers_sub1 > 0:
            if task_specific_layer:
                self.layer_sub1 = _new_block()
            self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if _needs_bridge():
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

        # second sub-task branch
        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = _new_block()
            self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if _needs_bridge():
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

        # main bridge comes last because it updates self._odim
        if _needs_bridge():
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)
Exemplo n.º 4
0
    def __init__(self, args, save_path=None):
        """Build the memory-transformer language model from ``args``.

        Args:
            args: configuration object; the attributes read here are
                ``lm_type``, ``transformer_d_model``, ``n_layers``,
                ``transformer_n_heads``, ``lsm_prob``, ``mem_len``, ``bptt``,
                ``recog_mem_len``, ``vocab``, ``dropout_in``,
                ``dropout_hidden``, ``dropout_att``, ``dropout_layer``,
                ``transformer_d_ff``, ``transformer_layer_norm_eps``,
                ``transformer_ffn_activation``, ``transformer_param_init``,
                ``adaptive_softmax`` and ``tie_embedding``.
            save_path: stored on the instance unchanged.

        The memory length is ``args.mem_len`` if positive, else ``args.bptt``;
        a positive ``args.recog_mem_len`` overrides either.
        """

        super(LMBase, self).__init__()
        logger.info(self.__class__.__name__)

        self.save_path = save_path
        self.lm_type = args.lm_type

        self.d_model = args.transformer_d_model
        self.n_layers = args.n_layers
        self.n_heads = args.transformer_n_heads
        self.lsm_prob = args.lsm_prob

        # memory length: training value, optionally overridden for recognition
        mem_len = args.mem_len if args.mem_len > 0 else args.bptt
        if args.recog_mem_len > 0:
            mem_len = args.recog_mem_len
        self.mem_len = mem_len

        self.vocab = args.vocab
        # NOTE: reserved in advance
        self.eos = 2
        self.pad = 3

        # for cache
        self.cache_theta = 0.2  # smoothing parameter
        self.cache_lambda = 0.2  # cache weight
        self.cache_ids = []
        self.cache_keys = []
        self.cache_attn = []
        self.embed_cache = None

        # positional embedding
        self.pos_emb = XLPositionalEmbedding(self.d_model, args.dropout_in)
        head_dim = self.d_model // self.n_heads
        # NOTE: u_bias and v_bias are global parameters
        self.u_bias = nn.Parameter(torch.Tensor(self.n_heads, head_dim))
        self.v_bias = nn.Parameter(torch.Tensor(self.n_heads, head_dim))

        # token embedding
        self.embed = nn.Embedding(self.vocab, self.d_model, padding_idx=self.pad)
        self.scale = math.sqrt(self.d_model)  # for token embedding
        self.dropout_emb = nn.Dropout(p=args.dropout_in)  # for token embedding

        # self-attention-only decoder blocks with memory
        blocks = []
        for _ in range(self.n_layers):
            block = TransformerDecoderBlock(
                self.d_model, args.transformer_d_ff, 'scaled_dot',
                self.n_heads,
                args.dropout_hidden, args.dropout_att, args.dropout_layer,
                args.transformer_layer_norm_eps,
                args.transformer_ffn_activation,
                args.transformer_param_init,
                src_tgt_attention=False, memory_transformer=True)
            blocks.append(copy.deepcopy(block))
        self.layers = nn.ModuleList(blocks)
        self.norm_out = nn.LayerNorm(self.d_model, eps=args.transformer_layer_norm_eps)

        # output layer: adaptive softmax or a plain linear projection
        self.adaptive_softmax = None
        self.output = None
        if not args.adaptive_softmax:
            self.output = nn.Linear(self.d_model, self.vocab)
            if args.tie_embedding:
                self.output.weight = self.embed.weight
        else:
            first_cutoff = round(self.vocab / 15)
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                self.d_model, self.vocab,
                cutoffs=[first_cutoff, 3 * first_cutoff],
                div_value=4.0)

        self.reset_parameters()
Exemplo n.º 5
0
    def __init__(self, input_dim, enc_type, n_heads, n_layers, n_layers_sub1,
                 n_layers_sub2, d_model, d_ff, ffn_bottleneck_dim,
                 last_proj_dim, pe_type, layer_norm_eps, ffn_activation,
                 dropout_in, dropout, dropout_att, dropout_layer, subsample,
                 subsample_type, n_stacks, n_splices, conv_in_channel,
                 conv_channels, conv_kernel_sizes, conv_strides, conv_poolings,
                 conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
                 conv_param_init, task_specific_layer, param_init,
                 chunk_size_left, chunk_size_current, chunk_size_right,
                 latency_control_type):
        """Store the encoder configuration and build its sub-modules.

        Args:
            enc_type (str): must be ``'transformer'`` or
                ``'conv_transformer'``; the latter adds a ``ConvEncoder``
                front-end (requires ``n_stacks == n_splices == 1``).
            subsample (str): '_'-separated integer factors, one per layer.
                Layers without an entry default to 1; entries beyond
                ``n_layers`` are ignored.
            subsample_type (str): one of ``'max_pool'``, ``'concat'``,
                ``'drop'`` or ``'1dconv'``.  Only consulted when the product
                of the subsampling factors exceeds 1.
            pe_type (str): ``'relative'`` selects ``XLPositionalEmbedding``;
                any other value is forwarded to ``PositionalEncoding``.
            chunk_size_left / chunk_size_current / chunk_size_right (int):
                positive values must be multiples of the total subsampling
                factor.
            last_proj_dim (int): when positive and different from
                ``self.output_dim``, linear bridge layers are added.
            conv_bottleneck_dim: accepted but not used in this constructor
                (``d_model`` is passed to ``ConvEncoder`` instead).

        The remaining arguments are stored on the instance or forwarded
        unchanged to the sub-module constructors.

        Raises:
            ValueError: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of
                range.
            NotImplementedError: if subsampling is requested with an unknown
                ``subsample_type``.
            AssertionError: if a configuration check fails.
        """

        super(TransformerEncoder, self).__init__()

        # parse subsample
        # NOTE: the spec is padded with 1 / truncated to exactly n_layers
        # entries, so no separate length validation is needed afterwards.
        subsamples = [1] * n_layers
        for lth, s in enumerate(map(int, subsample.split('_')[:n_layers])):
            subsamples[lth] = s

        if n_layers_sub1 < 0 or (n_layers_sub1 > 1
                                 and n_layers < n_layers_sub1):
            raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
        if n_layers_sub2 < 0 or (n_layers_sub2 > 1
                                 and n_layers_sub1 < n_layers_sub2):
            raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')
        assert enc_type in ['transformer', 'conv_transformer']

        self.d_model = d_model
        self.n_layers = n_layers
        self.n_heads = n_heads
        self.pe_type = pe_type
        self.scale = math.sqrt(d_model)

        # for streaming encoder
        self.chunk_size_left = chunk_size_left
        self.chunk_size_current = chunk_size_current
        self.chunk_size_right = chunk_size_right
        self.latency_controlled = chunk_size_left > 0 or chunk_size_current > 0 or chunk_size_right > 0
        self.lc_type = latency_control_type
        # reshape) not lookahead frames in CNN layers, but requires some additional computations
        # mask) there are some lookahead frames in CNN layers, no additional computations

        # TransformerXL like streaming encoder
        # NOTE(review): the enc_type assertion above only admits
        # 'transformer' / 'conv_transformer', so this flag is currently
        # always False and the branches guarded by it cannot run.
        self.memory_transformer = ('transformer_xl' in enc_type)
        self.mem_len = chunk_size_left
        if self.memory_transformer:
            assert pe_type == 'relative'
            assert chunk_size_left > 0
            assert chunk_size_current > 0

        # for hierarchical encoder
        self.n_layers_sub1 = n_layers_sub1
        self.n_layers_sub2 = n_layers_sub2
        self.task_specific_layer = task_specific_layer

        # for bridge layers
        self.bridge = None
        self.bridge_sub1 = None
        self.bridge_sub2 = None

        # for attention plot
        self.aws_dict = {}
        self.data_dict = {}

        # Setting for CNNs
        if 'conv' in enc_type:
            assert conv_channels
            assert n_stacks == 1 and n_splices == 1
            self.conv = ConvEncoder(input_dim,
                                    in_channel=conv_in_channel,
                                    channels=conv_channels,
                                    kernel_sizes=conv_kernel_sizes,
                                    strides=conv_strides,
                                    poolings=conv_poolings,
                                    dropout=0.,
                                    batch_norm=conv_batch_norm,
                                    layer_norm=conv_layer_norm,
                                    layer_norm_eps=layer_norm_eps,
                                    residual=False,
                                    bottleneck_dim=d_model,
                                    param_init=conv_param_init)
            self._odim = self.conv.output_dim
        else:
            self.conv = None
            self._odim = input_dim * n_splices * n_stacks
            self.embed = nn.Linear(self._odim, d_model)

        # calculate subsampling factor
        self._factor = 1
        if self.conv is not None:
            self._factor *= self.conv.subsampling_factor
        self.subsample = None
        total_subsample = np.prod(subsamples)
        if total_subsample > 1:
            self._factor *= total_subsample
            if subsample_type == 'max_pool':
                self.subsample = nn.ModuleList(
                    [MaxpoolSubsampler(factor) for factor in subsamples])
            elif subsample_type == 'concat':
                self.subsample = nn.ModuleList([
                    ConcatSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            elif subsample_type == 'drop':
                self.subsample = nn.ModuleList(
                    [DropSubsampler(factor) for factor in subsamples])
            elif subsample_type == '1dconv':
                self.subsample = nn.ModuleList([
                    Conv1dSubsampler(factor, self._odim)
                    for factor in subsamples
                ])
            else:
                # Without this, _factor would include a subsampling rate
                # that no layer actually applies.
                raise NotImplementedError(subsample_type)

        if self.chunk_size_left > 0:
            assert self.chunk_size_left % self._factor == 0
        if self.chunk_size_current > 0:
            assert self.chunk_size_current % self._factor == 0
        if self.chunk_size_right > 0:
            assert self.chunk_size_right % self._factor == 0

        # Define every positional attribute up front so that reading the
        # unused one yields None instead of raising AttributeError.
        self.pos_enc = None
        self.pos_emb = None
        self.u = None
        self.v = None
        if self.memory_transformer:
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
            self.u = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            self.v = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            # NOTE: u and v are global parameters
        elif pe_type == 'relative':
            self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        else:
            self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type,
                                              param_init)

        self.layers = nn.ModuleList([
            copy.deepcopy(
                TransformerEncoderBlock(d_model,
                                        d_ff,
                                        n_heads,
                                        dropout,
                                        dropout_att,
                                        dropout_layer,
                                        layer_norm_eps,
                                        ffn_activation,
                                        param_init,
                                        relative_attention=self.pos_emb
                                        is not None,
                                        ffn_bottleneck_dim=ffn_bottleneck_dim))
            for _ in range(n_layers)
        ])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        self._odim = d_model

        if n_layers_sub1 > 0:
            if task_specific_layer:
                # NOTE(review): unlike the main stack, relative_attention is
                # not passed here — confirm the block's default is intended.
                self.layer_sub1 = TransformerEncoderBlock(
                    d_model,
                    d_ff,
                    n_heads,
                    dropout,
                    dropout_att,
                    dropout_layer,
                    layer_norm_eps,
                    ffn_activation,
                    param_init,
                    ffn_bottleneck_dim=ffn_bottleneck_dim)
            self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

        if n_layers_sub2 > 0:
            if task_specific_layer:
                self.layer_sub2 = TransformerEncoderBlock(
                    d_model,
                    d_ff,
                    n_heads,
                    dropout,
                    dropout_att,
                    dropout_layer,
                    layer_norm_eps,
                    ffn_activation,
                    param_init,
                    ffn_bottleneck_dim=ffn_bottleneck_dim)
            self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
            if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

        # the main bridge must come last: it changes self._odim
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

        self.reset_parameters(param_init)