def __init__(self, input_dim, enc_type, n_heads,
             n_layers, n_layers_sub1, n_layers_sub2,
             d_model, d_ff, ffn_bottleneck_dim, ffn_activation,
             pe_type, layer_norm_eps, last_proj_dim,
             dropout_in, dropout, dropout_att, dropout_layer,
             subsample, subsample_type, n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
             conv_poolings, conv_batch_norm, conv_layer_norm,
             conv_bottleneck_dim, conv_param_init,
             task_specific_layer, param_init, clamp_len,
             lookahead, chunk_size_left, chunk_size_current, chunk_size_right,
             streaming_type):
    """Construct the Transformer encoder.

    Builds, in order:
      * an optional CNN front-end (when 'conv' is in ``enc_type``), otherwise
        a linear input embedding to ``d_model``;
      * optional per-layer subsampling modules;
      * a positional encoding (absolute) or a relative positional embedding;
      * ``n_layers`` ``TransformerEncoderBlock`` layers and an output LayerNorm;
      * optional hierarchical sub-task outputs (``n_layers_sub1/2``) and
        optional bridge projections to ``last_proj_dim``.

    Args (selected):
        subsample (str): '_'-separated per-layer subsampling factors.
            Missing entries default to 1; entries beyond ``n_layers`` are ignored.
        subsample_type (str): one of 'max_pool', 'concat', 'drop', '1dconv',
            'add'. Only consulted when the product of factors exceeds 1.
        lookahead (str): '_'-separated per-layer lookahead values, padded
            with 0. A non-zero total requires 'uni' in ``enc_type``.
        chunk_size_left/chunk_size_current/chunk_size_right (int or str):
            Streaming chunk sizes. When given as a '_'-separated string only
            the last value is used. Each is integer-divided by ``n_stacks``
            and must be divisible by the total subsampling factor.
        streaming_type (str): kept only when latency-controlled bidirectional
            mode is active (otherwise stored as ''). The comments below
            describe the 'reshape' and 'mask' schemes.
        pe_type (str): 'relative' / 'relative_xl' select a relative
            positional embedding; anything else goes to ``PositionalEncoding``.
        last_proj_dim (int): if > 0 and different from the encoder output
            dim, a linear bridge is added to this size.

    Raises:
        ValueError: on an inconsistent ``subsample`` length.
        Warning: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of range.
        NotImplementedError: if subsampling is requested with an unknown
            ``subsample_type``.
    """
    super(TransformerEncoder, self).__init__()

    # parse subsample: '_'-separated per-layer factors, padded with 1
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s
    # parse lookahead: '_'-separated per-layer lookahead, padded with 0
    lookaheads = [0] * n_layers
    for lth, s in enumerate(list(map(int, lookahead.split('_')[:n_layers]))):
        lookaheads[lth] = s

    # NOTE(review): `subsamples` is pre-sized to n_layers above, so this check
    # cannot fire; surplus entries in `subsample` are silently truncated.
    if len(subsamples) > 0 and len(subsamples) != n_layers:
        raise ValueError('subsample must be the same size as n_layers. n_layers: %d, subsample: %s' %
                         (n_layers, subsamples))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise Warning('Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d' %
                      (n_layers, n_layers_sub1))
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise Warning('Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d' %
                      (n_layers_sub1, n_layers_sub2))

    self.enc_type = enc_type
    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.scale = math.sqrt(d_model)

    # for compatibility: chunk sizes may arrive as int or as '_'-joined str
    chunk_size_left = str(chunk_size_left)
    chunk_size_current = str(chunk_size_current)
    chunk_size_right = str(chunk_size_right)

    # for streaming encoder
    self.unidir = 'uni' in enc_type
    self.lookaheads = lookaheads
    if sum(lookaheads) > 0:
        assert self.unidir
    # only the last '_'-separated value is used, measured after frame stacking
    self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
    self.chunk_size_current = int(chunk_size_current.split('_')[-1]) // n_stacks
    self.chunk_size_right = int(chunk_size_right.split('_')[-1]) // n_stacks
    self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
    self.cnn_lookahead = self.unidir or enc_type == 'conv'
    self.streaming_type = streaming_type if self.lc_bidir else ''
    # -: past context
    # *: current context
    # +: future context
    # reshape) overlapped windowing. additional redundant computation is introduced.
    # During inference, caching is not applied. However, considering (N_l+N_c+N_r) is very short
    # and independent of layer depth, the overhead is negligible.
    # chunk1: |**|++
    # chunk2:  --|**|++
    # chunk3:     --|**|++
    # chunk4:        --|**|++
    # chunk5:           --|**|++
    # mask) chunkwise masking. future context is restricted within the current chunk
    # to avoid accumulation of future context depending on the layer depth.
    # chunk1: |**|
    # chunk2:  --|**|
    # chunk3:  -- --|**|
    # chunk4:     -- --|**|
    # chunk5:        -- --|**|
    if self.unidir:
        assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
    if self.streaming_type == 'mask':
        assert self.chunk_size_right == 0
        assert self.chunk_size_left == self.chunk_size_current
        # NOTE: this is important to cache CNN output at each chunk
    if self.lc_bidir:
        assert n_layers_sub1 == 0
        assert n_layers_sub2 == 0
        assert not self.unidir

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # Setting for CNNs
    if 'conv' in enc_type:
        assert conv_channels
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                layer_norm_eps=layer_norm_eps,
                                residual=False,
                                bottleneck_dim=d_model,
                                param_init=conv_param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks
        self.embed = nn.Linear(self._odim, d_model)

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor
    self.subsample = None
    if np.prod(subsamples) > 1:
        self._factor *= np.prod(subsamples)
        # NOTE(review): concat/1dconv subsamplers are sized with self._odim
        # (CNN output dim or raw input dim), not d_model -- confirm this
        # matches the dimension of the layer outputs they are applied to.
        if subsample_type == 'max_pool':
            self.subsample = nn.ModuleList(
                [MaxpoolSubsampler(factor) for factor in subsamples])
        elif subsample_type == 'concat':
            self.subsample = nn.ModuleList(
                [ConcatSubsampler(factor, self._odim) for factor in subsamples])
        elif subsample_type == 'drop':
            self.subsample = nn.ModuleList(
                [DropSubsampler(factor) for factor in subsamples])
        elif subsample_type == '1dconv':
            self.subsample = nn.ModuleList(
                [Conv1dSubsampler(factor, self._odim) for factor in subsamples])
        elif subsample_type == 'add':
            self.subsample = nn.ModuleList(
                [AddSubsampler(factor) for factor in subsamples])
        else:
            # Previously fell through silently: _factor claimed subsampling
            # while no subsampling modules were built.
            raise NotImplementedError('subsample_type %s is not supported.' % subsample_type)

    # chunk sizes must align with the total subsampling factor
    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_current > 0:
        assert self.chunk_size_current % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    self.pos_enc, self.pos_emb = None, None
    self.u_bias, self.v_bias = None, None
    if pe_type in ['relative', 'relative_xl']:
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        if pe_type == 'relative_xl':
            self.u_bias = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            self.v_bias = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            # NOTE: u_bias and v_bias are global parameters shared in the whole model
    else:
        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type, param_init)

    self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderBlock(
        d_model, d_ff, n_heads, dropout, dropout_att, dropout_layer,
        layer_norm_eps, ffn_activation, param_init, pe_type, clamp_len,
        ffn_bottleneck_dim)) for _ in range(n_layers)])
    self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
    self._odim = d_model

    if n_layers_sub1 > 0:
        if task_specific_layer:
            self.layer_sub1 = TransformerEncoderBlock(
                d_model, d_ff, n_heads, dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init, pe_type, clamp_len,
                ffn_bottleneck_dim)
        odim_sub1 = d_model
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
            odim_sub1 = last_proj_dim
        # no extra norm when the sub-task taps the final layer
        if n_layers_sub1 == n_layers:
            self.norm_out_sub1 = None
        else:
            self.norm_out_sub1 = nn.LayerNorm(odim_sub1, eps=layer_norm_eps)

    if n_layers_sub2 > 0:
        if task_specific_layer:
            self.layer_sub2 = TransformerEncoderBlock(
                d_model, d_ff, n_heads, dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init, pe_type, clamp_len,
                ffn_bottleneck_dim)
        odim_sub2 = d_model
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
            odim_sub2 = last_proj_dim
        if n_layers_sub2 == n_layers:
            self.norm_out_sub2 = None
        else:
            self.norm_out_sub2 = nn.LayerNorm(odim_sub2, eps=layer_norm_eps)

    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
        self.bridge = nn.Linear(self._odim, last_proj_dim)
        self._odim = last_proj_dim

    self.reset_parameters(param_init)

    # for streaming inference
    self.reset_cache()
def __init__(self, input_dim, enc_type, n_heads,
             n_layers, n_layers_sub1, n_layers_sub2,
             d_model, d_ff, ffn_bottleneck_dim, last_proj_dim,
             pe_type, layer_norm_eps, ffn_activation,
             dropout_in, dropout, dropout_att, dropout_layer,
             subsample, subsample_type, n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
             conv_poolings, conv_batch_norm, conv_layer_norm,
             conv_bottleneck_dim, conv_param_init,
             task_specific_layer, param_init,
             chunk_size_left, chunk_size_current, chunk_size_right,
             latency_control_type):
    """Construct the Transformer encoder (latency-controlled variant).

    Builds an optional CNN front-end (``enc_type == 'conv_transformer'``) or a
    linear input embedding, optional per-layer subsampling modules, a
    positional encoding or relative positional embedding, ``n_layers``
    ``TransformerEncoderBlock`` layers, optional hierarchical sub-task outputs
    and optional bridge projections to ``last_proj_dim``.

    Args (selected):
        subsample (str): '_'-separated per-layer subsampling factors, padded
            with 1; entries beyond ``n_layers`` are ignored.
        subsample_type (str): one of 'max_pool', 'concat', 'drop', '1dconv'.
            Only consulted when the product of factors exceeds 1.
        chunk_size_left/chunk_size_current/chunk_size_right (int): chunk
            sizes; any positive value enables ``self.latency_controlled``. Each
            positive value must be divisible by the total subsampling factor.
        latency_control_type (str): stored as ``self.lc_type``.
        pe_type (str): 'relative' selects ``XLPositionalEmbedding``; anything
            else goes to ``PositionalEncoding``.

    Raises:
        ValueError: on invalid ``subsample`` / ``n_layers_sub1`` / ``n_layers_sub2``.
        NotImplementedError: if subsampling is requested with an unknown
            ``subsample_type``.
    """
    super(TransformerEncoder, self).__init__()

    # parse subsample: '_'-separated per-layer factors, padded with 1
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s

    # NOTE(review): `subsamples` is pre-sized to n_layers, so this cannot fire.
    if len(subsamples) > 0 and len(subsamples) != n_layers:
        raise ValueError('subsample must be the same size as n_layers. n_layers: %d, subsample: %s' %
                         (n_layers, subsamples))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')
    assert enc_type in ['transformer', 'conv_transformer']

    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.scale = math.sqrt(d_model)

    # for streaming encoder
    self.chunk_size_left = chunk_size_left
    self.chunk_size_current = chunk_size_current
    self.chunk_size_right = chunk_size_right
    self.latency_controlled = chunk_size_left > 0 or chunk_size_current > 0 or chunk_size_right > 0
    self.lc_type = latency_control_type
    # reshape) not lookahead frames in CNN layers, but requires some additional computations
    # mask) there are some lookahead frames in CNN layers, no additional computations

    # TransformerXL like streaming encoder
    # NOTE(review): the enc_type assert above excludes any name containing
    # 'transformer_xl', so memory_transformer is always False here.
    self.memory_transformer = ('transformer_xl' in enc_type)
    self.mem_len = chunk_size_left
    if self.memory_transformer:
        assert pe_type == 'relative'
        assert chunk_size_left > 0
        assert chunk_size_current > 0

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # Setting for CNNs
    if 'conv' in enc_type:
        assert conv_channels
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                layer_norm_eps=layer_norm_eps,
                                residual=False,
                                bottleneck_dim=d_model,
                                param_init=conv_param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks
        self.embed = nn.Linear(self._odim, d_model)

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor
    self.subsample = None
    if np.prod(subsamples) > 1:
        self._factor *= np.prod(subsamples)
        if subsample_type == 'max_pool':
            self.subsample = nn.ModuleList(
                [MaxpoolSubsampler(factor) for factor in subsamples])
        elif subsample_type == 'concat':
            self.subsample = nn.ModuleList(
                [ConcatSubsampler(factor, self._odim) for factor in subsamples])
        elif subsample_type == 'drop':
            self.subsample = nn.ModuleList(
                [DropSubsampler(factor) for factor in subsamples])
        elif subsample_type == '1dconv':
            self.subsample = nn.ModuleList(
                [Conv1dSubsampler(factor, self._odim) for factor in subsamples])
        else:
            # Previously fell through silently: _factor claimed subsampling
            # while no subsampling modules were built.
            raise NotImplementedError('subsample_type %s is not supported.' % subsample_type)

    # chunk sizes must align with the total subsampling factor
    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_current > 0:
        assert self.chunk_size_current % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    # Initialise every positional attribute so it exists on all code paths
    # (previously pos_enc was only set in the absolute-encoding branch).
    self.pos_enc = None
    self.pos_emb = None
    self.u = None
    self.v = None
    if self.memory_transformer:
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        self.u = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
        self.v = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
        # NOTE: u and v are global parameters
    elif pe_type == 'relative':
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
    else:
        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type, param_init)

    self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderBlock(
        d_model, d_ff, n_heads, dropout, dropout_att, dropout_layer,
        layer_norm_eps, ffn_activation, param_init,
        relative_attention=self.pos_emb is not None,
        ffn_bottleneck_dim=ffn_bottleneck_dim)) for _ in range(n_layers)])
    self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
    self._odim = d_model

    # NOTE(review): the task-specific blocks below are built without
    # relative_attention, unlike self.layers -- confirm this is intended
    # when pe_type == 'relative'.
    if n_layers_sub1 > 0:
        if task_specific_layer:
            self.layer_sub1 = TransformerEncoderBlock(
                d_model, d_ff, n_heads, dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)
        self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

    if n_layers_sub2 > 0:
        if task_specific_layer:
            self.layer_sub2 = TransformerEncoderBlock(
                d_model, d_ff, n_heads, dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)
        self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
        self.bridge = nn.Linear(self._odim, last_proj_dim)
        self._odim = last_proj_dim

    self.reset_parameters(param_init)
def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
             n_layers, n_layers_sub1, n_layers_sub2,
             dropout_in, dropout,
             subsample, subsample_type, n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
             conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
             bidir_sum_fwd_bwd, task_specific_layer, param_init,
             chunk_size_left, chunk_size_right):
    """Construct the RNN (LSTM/GRU) encoder.

    Builds an optional CNN front-end ('conv' in ``enc_type``) followed, unless
    ``enc_type == 'conv'``, by ``n_layers`` single-layer RNNs. Each RNN can be
    followed by a linear projection (all but the last layer, when
    ``n_projs > 0``) and a subsampling module. Optional task-specific RNNs
    for hierarchical sub-tasks and bridge projections can also be added.

    Args (selected):
        enc_type (str): must contain 'lstm' or 'gru' (unless it is 'conv').
            'blstm'/'bgru' make it bidirectional.
        subsample (str): '_'-separated per-layer subsampling factors, padded
            with 1; entries beyond ``n_layers`` are ignored.
        subsample_type (str): one of 'max_pool', 'concat', 'drop', '1dconv',
            'add'. Only consulted when the product of factors exceeds 1.
        chunk_size_left/chunk_size_right (int or str): latency-control chunk
            sizes. When given as a '_'-separated string only the first value is
            used. Each is integer-divided by ``n_stacks``. Any positive value
            enables latency-controlled mode, which uses separate
            unidirectional forward (``rnn``) and backward (``rnn_bwd``) stacks.

    Raises:
        ValueError: on invalid ``subsample`` / ``n_layers_sub*`` / ``enc_type``.
        NotImplementedError: if subsampling is requested with an unknown
            ``subsample_type``.
    """
    super(RNNEncoder, self).__init__()

    # parse subsample: '_'-separated per-layer factors, padded with 1
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s

    # NOTE(review): `subsamples` is pre-sized to n_layers, so this cannot fire.
    if len(subsamples) > 0 and len(subsamples) != n_layers:
        raise ValueError('subsample must be the same size as n_layers. n_layers: %d, subsample: %s' %
                         (n_layers, subsamples))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise ValueError('Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d' %
                         (n_layers, n_layers_sub1))
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d' %
                         (n_layers_sub1, n_layers_sub2))

    self.enc_type = enc_type
    self.bidirectional = 'blstm' in enc_type or 'bgru' in enc_type
    self.n_units = n_units
    self.n_dirs = 2 if self.bidirectional else 1
    self.n_layers = n_layers
    self.bidir_sum = bidir_sum_fwd_bwd

    # for latency-controlled; str() accepts int chunk sizes as well
    # (consistent with the other encoder constructors)
    self.chunk_size_left = int(str(chunk_size_left).split('_')[0]) // n_stacks
    self.chunk_size_right = int(str(chunk_size_right).split('_')[0]) // n_stacks
    self.lc_bidir = self.chunk_size_left > 0 or self.chunk_size_right > 0
    if self.lc_bidir:
        assert enc_type not in ['lstm', 'gru', 'conv_lstm', 'conv_gru']
        assert n_layers_sub2 == 0

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # Dropout for input-hidden connection
    self.dropout_in = nn.Dropout(p=dropout_in)

    if 'conv' in enc_type:
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                residual=False,
                                bottleneck_dim=conv_bottleneck_dim,
                                param_init=param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks

    if enc_type != 'conv':
        self.rnn = nn.ModuleList()
        if self.lc_bidir:
            self.rnn_bwd = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)
        self.proj = nn.ModuleList() if n_projs > 0 else None
        self.subsample = nn.ModuleList() if np.prod(subsamples) > 1 else None
        self.padding = Padding(bidir_sum_fwd_bwd=bidir_sum_fwd_bwd if not self.lc_bidir else False)

        for lth in range(n_layers):
            if 'lstm' in enc_type:
                rnn_i = nn.LSTM
            elif 'gru' in enc_type:
                rnn_i = nn.GRU
            else:
                raise ValueError('enc_type must be "(conv_)(b)lstm" or "(conv_)(b)gru".')

            if self.lc_bidir:
                # separate unidirectional forward / backward RNNs
                self.rnn += [rnn_i(self._odim, n_units, 1, batch_first=True)]
                self.rnn_bwd += [rnn_i(self._odim, n_units, 1, batch_first=True)]
            else:
                self.rnn += [rnn_i(self._odim, n_units, 1, batch_first=True,
                                   bidirectional=self.bidirectional)]
            self._odim = n_units if bidir_sum_fwd_bwd else n_units * self.n_dirs

            # Projection layer (not after the last layer)
            if self.proj is not None:
                if lth != n_layers - 1:
                    self.proj += [nn.Linear(self._odim, n_projs)]
                    self._odim = n_projs

            # subsample
            if np.prod(subsamples) > 1:
                if subsample_type == 'max_pool':
                    self.subsample += [MaxpoolSubsampler(subsamples[lth])]
                elif subsample_type == 'concat':
                    self.subsample += [ConcatSubsampler(subsamples[lth], self._odim)]
                elif subsample_type == 'drop':
                    self.subsample += [DropSubsampler(subsamples[lth])]
                elif subsample_type == '1dconv':
                    self.subsample += [Conv1dSubsampler(subsamples[lth], self._odim)]
                elif subsample_type == 'add':
                    self.subsample += [AddSubsampler(subsamples[lth])]
                else:
                    # Previously skipped silently, leaving self.subsample
                    # shorter than n_layers.
                    raise NotImplementedError('subsample_type %s is not supported.' % subsample_type)

            # Task specific layer
            if lth == n_layers_sub1 - 1 and task_specific_layer:
                self.rnn_sub1 = rnn_i(self._odim, n_units, 1, batch_first=True,
                                      bidirectional=self.bidirectional)
                if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                    self.bridge_sub1 = nn.Linear(n_units, last_proj_dim)
            if lth == n_layers_sub2 - 1 and task_specific_layer:
                assert not self.lc_bidir
                self.rnn_sub2 = rnn_i(self._odim, n_units, 1, batch_first=True,
                                      bidirectional=self.bidirectional)
                if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                    self.bridge_sub2 = nn.Linear(n_units, last_proj_dim)

        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor
    # The RNN subsampling modules above are built whenever the product of the
    # factors exceeds 1 (CNN front-end or not), so their factor must compound
    # with the CNN factor; previously an `elif` dropped it when a CNN existed.
    if enc_type != 'conv' and np.prod(subsamples) > 1:
        self._factor *= np.prod(subsamples)
    # NOTE: subsampling factor for frame stacking should not be included here
    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    self.reset_parameters(param_init)

    # for streaming inference
    self.reset_cache()
def __init__(self, input_dim, enc_type, n_units, n_projs, last_proj_dim,
             n_layers, n_layers_sub1, n_layers_sub2,
             dropout_in, dropout,
             subsample, subsample_type, n_stacks, n_splices,
             frontend_conv, bidir_sum_fwd_bwd, task_specific_layer, param_init,
             chunk_size_current, chunk_size_right, cnn_lookahead, rsp_prob):
    """Construct the LSTM encoder (external front-end variant).

    Uses ``frontend_conv`` (may be None) as the CNN front-end. Unless
    ``enc_type == 'conv'``, it then builds ``n_layers`` single-layer LSTMs,
    optional task-specific linear layers for hierarchical sub-tasks, optional
    projections (all but the last layer, when ``n_projs > 0``), optional
    per-layer subsampling, and bridge projections to ``last_proj_dim``.

    Args (selected):
        enc_type (str): 'blstm' anywhere in the name makes it bidirectional.
        subsample (str): '_'-separated per-layer subsampling factors, padded
            with 1; entries beyond ``n_layers`` are ignored.
        subsample_type (str): one of 'max_pool', 'mean_pool', 'concat',
            'drop', 'conv1d', 'add'. Only consulted when the product of the
            factors exceeds 1.
        chunk_size_current/chunk_size_right (int or str): latency-control
            chunk sizes. Only the first '_'-separated value is used, and each
            is integer-divided by ``n_stacks`` (stored as ``N_c`` / ``N_r``).
            Latency-controlled mode (``lc_bidir``) also requires a
            bidirectional encoder.
        cnn_lookahead (bool): if False, requires ``N_c > 0`` and ``lc_bidir``.
        rsp_prob: stored as ``self.rsp_prob``.

    Raises:
        ValueError: on an inconsistent ``subsample`` length.
        Warning: if ``n_layers_sub1`` / ``n_layers_sub2`` are out of range.
        NotImplementedError: if subsampling is requested with an unknown
            ``subsample_type``.
    """
    super(RNNEncoder, self).__init__()

    # parse subsample: '_'-separated per-layer factors, padded with 1
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s

    # NOTE(review): `subsamples` is pre-sized to n_layers, so this cannot fire.
    if len(subsamples) > 0 and len(subsamples) != n_layers:
        raise ValueError('subsample must be the same size as n_layers. n_layers: %d, subsample: %s' %
                         (n_layers, subsamples))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise Warning('Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d' %
                      (n_layers, n_layers_sub1))
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise Warning('Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d' %
                      (n_layers_sub1, n_layers_sub2))

    self.enc_type = enc_type
    self.bidirectional = 'blstm' in enc_type
    self.n_units = n_units
    self.n_dirs = 2 if self.bidirectional else 1
    self.n_layers = n_layers
    self.bidir_sum = bidir_sum_fwd_bwd

    # for compatiblity: chunk sizes may arrive as int or '_'-joined str
    chunk_size_current = str(chunk_size_current)
    chunk_size_right = str(chunk_size_right)

    # for latency-controlled
    self.N_c = int(chunk_size_current.split('_')[0]) // n_stacks
    self.N_r = int(chunk_size_right.split('_')[0]) // n_stacks
    self.lc_bidir = (self.N_c > 0 or self.N_r > 0) and self.bidirectional
    if self.lc_bidir:
        assert enc_type not in ['lstm', 'conv_lstm']
        assert n_layers_sub2 == 0

    # for streaming
    self.rsp_prob = rsp_prob

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # Dropout for input-hidden connection
    self.dropout_in = nn.Dropout(p=dropout_in)

    self.conv = frontend_conv
    if self.conv is not None:
        self._odim = self.conv.output_dim
    else:
        self._odim = input_dim * n_splices * n_stacks
    self.cnn_lookahead = cnn_lookahead
    if not cnn_lookahead:
        assert self.N_c > 0
        assert self.lc_bidir

    if enc_type != 'conv':
        self.rnn = nn.ModuleList()
        if self.lc_bidir:
            self.rnn_bwd = nn.ModuleList()
        self.dropout = nn.Dropout(p=dropout)
        self.proj = nn.ModuleList() if n_projs > 0 else None
        self.subsample = nn.ModuleList() if np.prod(subsamples) > 1 else None
        self.padding = Padding(bidir_sum_fwd_bwd=bidir_sum_fwd_bwd if not self.lc_bidir else False)

        for lth in range(n_layers):
            if self.lc_bidir:
                # separate unidirectional forward / backward LSTMs
                self.rnn += [nn.LSTM(self._odim, n_units, 1, batch_first=True)]
                self.rnn_bwd += [nn.LSTM(self._odim, n_units, 1, batch_first=True)]
            else:
                self.rnn += [nn.LSTM(self._odim, n_units, 1, batch_first=True,
                                     bidirectional=self.bidirectional)]
            self._odim = n_units if bidir_sum_fwd_bwd else n_units * self.n_dirs

            # Task specific layer (taps this layer's output, before projection)
            if lth == n_layers_sub1 - 1 and task_specific_layer:
                self.layer_sub1 = nn.Linear(self._odim, n_units)
                self._odim_sub1 = n_units
                if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                    self.bridge_sub1 = nn.Linear(n_units, last_proj_dim)
                    self._odim_sub1 = last_proj_dim
            if lth == n_layers_sub2 - 1 and task_specific_layer:
                assert not self.lc_bidir
                self.layer_sub2 = nn.Linear(self._odim, n_units)
                self._odim_sub2 = n_units
                if last_proj_dim > 0 and last_proj_dim != self.output_dim:
                    self.bridge_sub2 = nn.Linear(n_units, last_proj_dim)
                    self._odim_sub2 = last_proj_dim

            # Projection layer (not after the last layer)
            if self.proj is not None:
                if lth != n_layers - 1:
                    self.proj += [nn.Linear(self._odim, n_projs)]
                    self._odim = n_projs

            # subsample
            if np.prod(subsamples) > 1:
                if subsample_type == 'max_pool':
                    self.subsample += [MaxPoolSubsampler(subsamples[lth])]
                elif subsample_type == 'mean_pool':
                    self.subsample += [MeanPoolSubsampler(subsamples[lth])]
                elif subsample_type == 'concat':
                    self.subsample += [ConcatSubsampler(subsamples[lth], self._odim)]
                elif subsample_type == 'drop':
                    self.subsample += [DropSubsampler(subsamples[lth])]
                elif subsample_type == 'conv1d':
                    self.subsample += [Conv1dSubsampler(subsamples[lth], self._odim)]
                elif subsample_type == 'add':
                    self.subsample += [AddSubsampler(subsamples[lth])]
                else:
                    # Previously skipped silently, leaving self.subsample
                    # shorter than n_layers.
                    raise NotImplementedError('subsample_type %s is not supported.' % subsample_type)

        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge = nn.Linear(self._odim, last_proj_dim)
            self._odim = last_proj_dim

    # calculate subsampling factor
    self.conv_factor = self.conv.subsampling_factor if self.conv is not None else 1
    self._factor = self.conv_factor
    self._factor_sub1 = self.conv_factor
    self._factor_sub2 = self.conv_factor
    # sub-task outputs see only the subsampling of the layers before them
    if n_layers_sub1 > 1:
        self._factor_sub1 *= np.prod(subsamples[:n_layers_sub1 - 1])
    if n_layers_sub2 > 1:
        # fixed: this previously multiplied into _factor_sub1 by mistake
        self._factor_sub2 *= np.prod(subsamples[:n_layers_sub2 - 1])
    self._factor *= np.prod(subsamples)
    # NOTE: subsampling factor for frame stacking should not be included here
    if self.N_c > 0:
        assert self.N_c % self._factor == 0
    if self.N_r > 0:
        assert self.N_r % self._factor == 0

    self.reset_parameters(param_init)

    # for streaming inference
    self.reset_cache()