def __init__(self, special_symbols, enc_n_units, attn_type, n_heads, n_layers,
             d_model, d_ff, ffn_bottleneck_dim, pe_type, layer_norm_eps, ffn_activation,
             vocab, tie_embedding, dropout, dropout_emb, dropout_att, dropout_layer,
             dropout_head, lsm_prob, ctc_weight, ctc_lsm_prob, ctc_fc_list, backward,
             global_weight, mtl_per_batch, param_init, memory_transformer, mem_len,
             mocha_chunk_size, mocha_n_heads_mono, mocha_n_heads_chunk, mocha_init_r,
             mocha_eps, mocha_std, mocha_no_denominator, mocha_1dconv,
             mocha_quantity_loss_weight, mocha_head_divergence_loss_weight,
             latency_metric, latency_loss_weight, mocha_first_layer,
             share_chunkwise_attention, external_lm, lm_fusion):

    super(TransformerDecoder, self).__init__()

    self.eos = special_symbols['eos']
    self.unk = special_symbols['unk']
    self.pad = special_symbols['pad']
    self.blank = special_symbols['blank']
    self.vocab = vocab
    self.enc_n_units = enc_n_units
    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.lsm_prob = lsm_prob
    self.att_weight = global_weight - ctc_weight
    self.ctc_weight = ctc_weight
    self.bwd = backward
    self.mtl_per_batch = mtl_per_batch

    self.prev_spk = ''
    self.lmstate_final = None

    # for TransformerXL decoder
    self.memory_transformer = memory_transformer
    self.mem_len = mem_len
    if memory_transformer:
        assert pe_type == 'none'

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # for MMA
    self.attn_type = attn_type
    self.quantity_loss_weight = mocha_quantity_loss_weight
    self._quantity_loss_weight = 0  # for curriculum
    self.mocha_first_layer = mocha_first_layer
    self.headdiv_loss_weight = mocha_head_divergence_loss_weight

    self.latency_metric = latency_metric
    self.latency_loss_weight = latency_loss_weight
    self.ctc_trigger = (self.latency_metric in ['ctc_sync'])
    if self.ctc_trigger:
        assert 0 < self.ctc_weight < 1

    if ctc_weight > 0:
        self.ctc = CTC(eos=self.eos, blank=self.blank,
                       enc_n_units=enc_n_units, vocab=self.vocab,
                       dropout=dropout, lsm_prob=ctc_lsm_prob,
                       fc_list=ctc_fc_list, param_init=0.1,
                       backward=backward)

    if self.att_weight > 0:
        # token embedding
        self.embed = nn.Embedding(self.vocab, d_model, padding_idx=self.pad)
        self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type, param_init)
        # positional embedding
        self.u = None
        self.v = None
        if memory_transformer:
            self.scale = math.sqrt(d_model)  # for token embedding
            self.dropout_emb = nn.Dropout(p=dropout_emb)  # for token embedding
            self.pos_emb = XLPositionalEmbedding(d_model, dropout_emb)
            if self.mem_len > 0:
                self.u = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
                self.v = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
                # NOTE: u and v are global parameters

        # self-attention
        assert mocha_first_layer <= n_layers
        self.layers = nn.ModuleList([copy.deepcopy(TransformerDecoderBlock(
            d_model, d_ff, attn_type, n_heads,
            dropout, dropout_att, dropout_layer,
            layer_norm_eps, ffn_activation, param_init,
            src_tgt_attention=False if lth < mocha_first_layer - 1 else True,
            memory_transformer=memory_transformer,
            mocha_chunk_size=mocha_chunk_size,
            mocha_n_heads_mono=mocha_n_heads_mono,
            mocha_n_heads_chunk=mocha_n_heads_chunk,
            mocha_init_r=mocha_init_r,
            mocha_eps=mocha_eps,
            mocha_std=mocha_std,
            mocha_no_denominator=mocha_no_denominator,
            mocha_1dconv=mocha_1dconv,
            dropout_head=dropout_head,
            lm_fusion=lm_fusion,
            ffn_bottleneck_dim=ffn_bottleneck_dim,
            share_chunkwise_attention=share_chunkwise_attention))
            for lth in range(n_layers)])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
        self.output = nn.Linear(d_model, self.vocab)
        if tie_embedding:
            self.output.weight = self.embed.weight

        self.lm = external_lm
        if external_lm is not None:
            self.lm_output_proj = nn.Linear(external_lm.output_dim, d_model)

        self.reset_parameters(param_init)
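A minimal sketch, not taken from the source, of two pieces of the constructor above: the attention/CTC loss split (`att_weight = global_weight - ctc_weight`) and the per-layer `src_tgt_attention` flags controlled by `mocha_first_layer`. The helper names and the concrete numbers are illustrative.

def split_weights(global_weight, ctc_weight):
    # mirrors `self.att_weight = global_weight - ctc_weight` above
    return global_weight - ctc_weight, ctc_weight

def src_tgt_attention_flags(n_layers, mocha_first_layer):
    # layers below `mocha_first_layer` are built without source-target (cross) attention,
    # matching `src_tgt_attention=False if lth < mocha_first_layer - 1 else True`
    return [lth >= mocha_first_layer - 1 for lth in range(n_layers)]

if __name__ == '__main__':
    att_w, ctc_w = split_weights(global_weight=1.0, ctc_weight=0.3)
    print(att_w, ctc_w)                    # 0.7 0.3
    print(src_tgt_attention_flags(6, 4))   # [False, False, False, True, True, True]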
def __init__(self, input_dim, enc_type, n_heads,
             n_layers, n_layers_sub1, n_layers_sub2,
             d_model, d_ff, ffn_bottleneck_dim, ffn_activation,
             pe_type, layer_norm_eps, last_proj_dim,
             dropout_in, dropout, dropout_att, dropout_layer,
             subsample, subsample_type, n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
             conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
             conv_param_init, task_specific_layer, param_init, clamp_len,
             lookahead, chunk_size_left, chunk_size_current, chunk_size_right,
             streaming_type):

    super(TransformerEncoder, self).__init__()

    # parse subsample
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s
    # parse lookahead
    lookaheads = [0] * n_layers
    for lth, s in enumerate(list(map(int, lookahead.split('_')[:n_layers]))):
        lookaheads[lth] = s

    if len(subsamples) > 0 and len(subsamples) != n_layers:
        raise ValueError('subsample must be the same size as n_layers. n_layers: %d, subsample: %s'
                         % (n_layers, subsamples))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise Warning('Set n_layers_sub1 between 1 to n_layers. n_layers: %d, n_layers_sub1: %d'
                      % (n_layers, n_layers_sub1))
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise Warning('Set n_layers_sub2 between 1 to n_layers_sub1. n_layers_sub1: %d, n_layers_sub2: %d'
                      % (n_layers_sub1, n_layers_sub2))

    self.enc_type = enc_type
    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.scale = math.sqrt(d_model)

    # for compatibility
    chunk_size_left = str(chunk_size_left)
    chunk_size_current = str(chunk_size_current)
    chunk_size_right = str(chunk_size_right)

    # for streaming encoder
    self.unidir = 'uni' in enc_type
    self.lookaheads = lookaheads
    if sum(lookaheads) > 0:
        assert self.unidir
    self.chunk_size_left = int(chunk_size_left.split('_')[-1]) // n_stacks
    self.chunk_size_current = int(chunk_size_current.split('_')[-1]) // n_stacks
    self.chunk_size_right = int(chunk_size_right.split('_')[-1]) // n_stacks
    self.lc_bidir = self.chunk_size_current > 0 and enc_type != 'conv' and 'uni' not in enc_type
    self.cnn_lookahead = self.unidir or enc_type == 'conv'
    self.streaming_type = streaming_type if self.lc_bidir else ''
    # -: past context
    # *: current context
    # +: future context
    # reshape) overlapped windowing; additional redundant computation is introduced.
    # During inference, caching is not applied. However, since (N_l + N_c + N_r) is very short
    # and independent of layer depth, the overhead is negligible.
    # chunk1: |**|++
    # chunk2: --|**|++
    # chunk3:    --|**|++
    # chunk4:       --|**|++
    # chunk5:          --|**|++
    # mask) chunkwise masking; future context is restricted within the current chunk
    # to avoid accumulation of future context depending on the layer depth.
    # chunk1: |**|
    # chunk2: --|**|
    # chunk3: -- --|**|
    # chunk4:    -- --|**|
    # chunk5:       -- --|**|

    if self.unidir:
        assert self.chunk_size_left == self.chunk_size_current == self.chunk_size_right == 0
    if self.streaming_type == 'mask':
        assert self.chunk_size_right == 0
        assert self.chunk_size_left == self.chunk_size_current
        # NOTE: this is important to cache CNN output at each chunk
    if self.lc_bidir:
        assert n_layers_sub1 == 0
        assert n_layers_sub2 == 0
        assert not self.unidir

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # Setting for CNNs
    if 'conv' in enc_type:
        assert conv_channels
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                layer_norm_eps=layer_norm_eps,
                                residual=False,
                                bottleneck_dim=d_model,
                                param_init=conv_param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks
        self.embed = nn.Linear(self._odim, d_model)

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor
    self.subsample = None
    if np.prod(subsamples) > 1:
        self._factor *= np.prod(subsamples)
        if subsample_type == 'max_pool':
            self.subsample = nn.ModuleList([MaxpoolSubsampler(factor) for factor in subsamples])
        elif subsample_type == 'concat':
            self.subsample = nn.ModuleList([ConcatSubsampler(factor, self._odim) for factor in subsamples])
        elif subsample_type == 'drop':
            self.subsample = nn.ModuleList([DropSubsampler(factor) for factor in subsamples])
        elif subsample_type == '1dconv':
            self.subsample = nn.ModuleList([Conv1dSubsampler(factor, self._odim) for factor in subsamples])
        elif subsample_type == 'add':
            self.subsample = nn.ModuleList([AddSubsampler(factor) for factor in subsamples])

    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_current > 0:
        assert self.chunk_size_current % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    self.pos_enc, self.pos_emb = None, None
    self.u_bias, self.v_bias = None, None
    if pe_type in ['relative', 'relative_xl']:
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        if pe_type == 'relative_xl':
            self.u_bias = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            self.v_bias = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
            # NOTE: u_bias and v_bias are global parameters shared in the whole model
    else:
        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type, param_init)

    self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderBlock(
        d_model, d_ff, n_heads,
        dropout, dropout_att, dropout_layer,
        layer_norm_eps, ffn_activation, param_init,
        pe_type, clamp_len, ffn_bottleneck_dim))
        for _ in range(n_layers)])
    self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
    self._odim = d_model

    if n_layers_sub1 > 0:
        if task_specific_layer:
            self.layer_sub1 = TransformerEncoderBlock(
                d_model, d_ff, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                pe_type, clamp_len, ffn_bottleneck_dim)
        odim_sub1 = d_model
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)
            odim_sub1 = last_proj_dim
        if n_layers_sub1 == n_layers:
            self.norm_out_sub1 = None
        else:
            self.norm_out_sub1 = nn.LayerNorm(odim_sub1, eps=layer_norm_eps)

    if n_layers_sub2 > 0:
        if task_specific_layer:
            self.layer_sub2 = TransformerEncoderBlock(
                d_model, d_ff, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                pe_type, clamp_len, ffn_bottleneck_dim)
        odim_sub2 = d_model
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)
            odim_sub2 = last_proj_dim
        if n_layers_sub2 == n_layers:
            self.norm_out_sub2 = None
        else:
            self.norm_out_sub2 = nn.LayerNorm(odim_sub2, eps=layer_norm_eps)

    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
        self.bridge = nn.Linear(self._odim, last_proj_dim)
        self._odim = last_proj_dim

    self.reset_parameters(param_init)

    # for streaming inference
    self.reset_cache()
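A toy sketch of the single-layer chunkwise attention mask described by the `mask` streaming diagram above: each frame may attend to its own chunk plus N_l past frames, and future context never crosses the current chunk boundary. The helper name and exact windowing are illustrative assumptions, not the repository's implementation.

import torch

def chunkwise_mask(xmax, N_l, N_c):
    # True means "may attend"; future frames beyond the current chunk stay masked,
    # so future context does not accumulate with layer depth.
    mask = torch.zeros(xmax, xmax, dtype=torch.bool)
    for t in range(xmax):
        chunk_start = (t // N_c) * N_c
        left = max(0, chunk_start - N_l)
        right = min(xmax, chunk_start + N_c)
        mask[t, left:right] = True
    return mask

# e.g. 8 frames, chunks of 4, with 4 past frames visible (N_l == N_c, as asserted above)
print(chunkwise_mask(8, N_l=4, N_c=4).int())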
def __init__(self, input_dim, enc_type, n_heads, kernel_size,
             n_layers, n_layers_sub1, n_layers_sub2,
             d_model, d_ff, ffn_bottleneck_dim, last_proj_dim,
             pe_type, layer_norm_eps, ffn_activation,
             dropout_in, dropout, dropout_att, dropout_layer,
             n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
             conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
             conv_param_init, task_specific_layer, param_init,
             chunk_size_left, chunk_size_current, chunk_size_right):

    super(ConformerEncoder, self).__init__()

    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')

    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.scale = math.sqrt(d_model)

    # for streaming encoder
    self.chunk_size_left = chunk_size_left
    self.chunk_size_current = chunk_size_current
    self.chunk_size_right = chunk_size_right
    self.latency_controlled = chunk_size_left > 0 or chunk_size_current > 0 or chunk_size_right > 0

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # Setting for CNNs
    if 'conv' in enc_type:
        assert conv_channels
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                layer_norm_eps=layer_norm_eps,
                                residual=False,
                                bottleneck_dim=d_model,
                                param_init=conv_param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks
        self.embed = nn.Linear(self._odim, d_model)

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor

    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_current > 0:
        assert self.chunk_size_current % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    self.pos_emb = XLPositionalEmbedding(d_model, dropout)
    assert pe_type == 'relative'
    # TODO(hirofumi0810): try other positional encodings

    self.layers = nn.ModuleList([copy.deepcopy(ConformerEncoderBlock(
        d_model, d_ff, n_heads, kernel_size,
        dropout, dropout_att, dropout_layer,
        layer_norm_eps, ffn_activation, param_init,
        ffn_bottleneck_dim=ffn_bottleneck_dim))
        for _ in range(n_layers)])
    self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
    self._odim = d_model

    if n_layers_sub1 > 0:
        if task_specific_layer:
            self.layer_sub1 = ConformerEncoderBlock(
                d_model, d_ff, n_heads, kernel_size,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)
        self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

    if n_layers_sub2 > 0:
        if task_specific_layer:
            self.layer_sub2 = ConformerEncoderBlock(
                d_model, d_ff, n_heads, kernel_size,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)
        self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
        self.bridge = nn.Linear(self._odim, last_proj_dim)
        self._odim = last_proj_dim

    self.reset_parameters(param_init)
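Illustrative arithmetic for the chunk-size assertions above, assuming a convolutional subsampling factor of 4 (the value of `self._factor` is model-dependent): chunk sizes given in input frames must be multiples of the factor so that each chunk maps to an integer number of encoder frames.

conv_subsampling_factor = 4            # assumed value of self._factor
chunk_size_left, chunk_size_current = 64, 64
assert chunk_size_left % conv_subsampling_factor == 0
assert chunk_size_current % conv_subsampling_factor == 0
print(chunk_size_current // conv_subsampling_factor)  # 16 encoder frames per chunk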
def __init__(self, args, save_path=None):

    super(LMBase, self).__init__()
    logger.info(self.__class__.__name__)

    self.lm_type = args.lm_type
    self.save_path = save_path

    self.d_model = args.transformer_d_model
    self.n_layers = args.n_layers
    self.n_heads = args.transformer_n_heads
    self.lsm_prob = args.lsm_prob

    if args.mem_len > 0:
        self.mem_len = args.mem_len
    else:
        self.mem_len = args.bptt
    if args.recog_mem_len > 0:
        self.mem_len = args.recog_mem_len

    self.vocab = args.vocab
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    # for cache
    self.cache_theta = 0.2  # smoothing parameter
    self.cache_lambda = 0.2  # cache weight
    self.cache_ids = []
    self.cache_keys = []
    self.cache_attn = []
    self.embed_cache = None

    # positional embedding
    self.pos_emb = XLPositionalEmbedding(self.d_model, args.dropout_in)
    self.u_bias = nn.Parameter(torch.Tensor(self.n_heads, self.d_model // self.n_heads))
    self.v_bias = nn.Parameter(torch.Tensor(self.n_heads, self.d_model // self.n_heads))
    # NOTE: u_bias and v_bias are global parameters

    self.embed = nn.Embedding(self.vocab, self.d_model, padding_idx=self.pad)
    self.scale = math.sqrt(self.d_model)  # for token embedding
    self.dropout_emb = nn.Dropout(p=args.dropout_in)  # for token embedding

    self.layers = nn.ModuleList([copy.deepcopy(TransformerDecoderBlock(
        self.d_model, args.transformer_d_ff, 'scaled_dot', self.n_heads,
        args.dropout_hidden, args.dropout_att, args.dropout_layer,
        args.transformer_layer_norm_eps, args.transformer_ffn_activation,
        args.transformer_param_init,
        src_tgt_attention=False, memory_transformer=True))
        for lth in range(self.n_layers)])
    self.norm_out = nn.LayerNorm(self.d_model, eps=args.transformer_layer_norm_eps)

    self.adaptive_softmax = None
    self.output = None
    if args.adaptive_softmax:
        self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
            self.d_model, self.vocab,
            cutoffs=[round(self.vocab / 15), 3 * round(self.vocab / 15)],
            # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
            div_value=4.0)
    else:
        self.output = nn.Linear(self.d_model, self.vocab)
        if args.tie_embedding:
            self.output.weight = self.embed.weight

    self.reset_parameters()
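A minimal sketch of the memory-length selection logic above: `mem_len` falls back to `bptt` when unset, and `recog_mem_len` overrides both at recognition time. The helper name and the argument values are illustrative.

def resolve_mem_len(mem_len, bptt, recog_mem_len):
    # mirrors the if/else cascade in the constructor above
    resolved = mem_len if mem_len > 0 else bptt
    if recog_mem_len > 0:
        resolved = recog_mem_len
    return resolved

print(resolve_mem_len(mem_len=0, bptt=256, recog_mem_len=0))     # 256 (falls back to bptt)
print(resolve_mem_len(mem_len=500, bptt=256, recog_mem_len=0))   # 500
print(resolve_mem_len(mem_len=500, bptt=256, recog_mem_len=64))  # 64 (overridden for recognition)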
def __init__(self, input_dim, enc_type, n_heads,
             n_layers, n_layers_sub1, n_layers_sub2,
             d_model, d_ff, ffn_bottleneck_dim, last_proj_dim,
             pe_type, layer_norm_eps, ffn_activation,
             dropout_in, dropout, dropout_att, dropout_layer,
             subsample, subsample_type, n_stacks, n_splices,
             conv_in_channel, conv_channels, conv_kernel_sizes, conv_strides,
             conv_poolings, conv_batch_norm, conv_layer_norm, conv_bottleneck_dim,
             conv_param_init, task_specific_layer, param_init,
             chunk_size_left, chunk_size_current, chunk_size_right,
             latency_control_type):

    super(TransformerEncoder, self).__init__()

    # parse subsample
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s

    if len(subsamples) > 0 and len(subsamples) != n_layers:
        raise ValueError('subsample must be the same size as n_layers. n_layers: %d, subsample: %s'
                         % (n_layers, subsamples))
    if n_layers_sub1 < 0 or (n_layers_sub1 > 1 and n_layers < n_layers_sub1):
        raise ValueError('Set n_layers_sub1 between 1 to n_layers.')
    if n_layers_sub2 < 0 or (n_layers_sub2 > 1 and n_layers_sub1 < n_layers_sub2):
        raise ValueError('Set n_layers_sub2 between 1 to n_layers_sub1.')
    assert enc_type in ['transformer', 'conv_transformer']

    self.d_model = d_model
    self.n_layers = n_layers
    self.n_heads = n_heads
    self.pe_type = pe_type
    self.scale = math.sqrt(d_model)

    # for streaming encoder
    self.chunk_size_left = chunk_size_left
    self.chunk_size_current = chunk_size_current
    self.chunk_size_right = chunk_size_right
    self.latency_controlled = chunk_size_left > 0 or chunk_size_current > 0 or chunk_size_right > 0
    self.lc_type = latency_control_type
    # reshape) no lookahead frames in the CNN layers, but some additional computation is required
    # mask) lookahead frames are used in the CNN layers, but no additional computation is required

    # TransformerXL-like streaming encoder
    self.memory_transformer = ('transformer_xl' in enc_type)
    self.mem_len = chunk_size_left
    if self.memory_transformer:
        assert pe_type == 'relative'
        assert chunk_size_left > 0
        assert chunk_size_current > 0

    # for hierarchical encoder
    self.n_layers_sub1 = n_layers_sub1
    self.n_layers_sub2 = n_layers_sub2
    self.task_specific_layer = task_specific_layer

    # for bridge layers
    self.bridge = None
    self.bridge_sub1 = None
    self.bridge_sub2 = None

    # for attention plot
    self.aws_dict = {}
    self.data_dict = {}

    # Setting for CNNs
    if 'conv' in enc_type:
        assert conv_channels
        assert n_stacks == 1 and n_splices == 1
        self.conv = ConvEncoder(input_dim,
                                in_channel=conv_in_channel,
                                channels=conv_channels,
                                kernel_sizes=conv_kernel_sizes,
                                strides=conv_strides,
                                poolings=conv_poolings,
                                dropout=0.,
                                batch_norm=conv_batch_norm,
                                layer_norm=conv_layer_norm,
                                layer_norm_eps=layer_norm_eps,
                                residual=False,
                                bottleneck_dim=d_model,
                                param_init=conv_param_init)
        self._odim = self.conv.output_dim
    else:
        self.conv = None
        self._odim = input_dim * n_splices * n_stacks
        self.embed = nn.Linear(self._odim, d_model)

    # calculate subsampling factor
    self._factor = 1
    if self.conv is not None:
        self._factor *= self.conv.subsampling_factor
    self.subsample = None
    if np.prod(subsamples) > 1:
        self._factor *= np.prod(subsamples)
        if subsample_type == 'max_pool':
            self.subsample = nn.ModuleList([MaxpoolSubsampler(factor) for factor in subsamples])
        elif subsample_type == 'concat':
            self.subsample = nn.ModuleList([ConcatSubsampler(factor, self._odim) for factor in subsamples])
        elif subsample_type == 'drop':
            self.subsample = nn.ModuleList([DropSubsampler(factor) for factor in subsamples])
        elif subsample_type == '1dconv':
            self.subsample = nn.ModuleList([Conv1dSubsampler(factor, self._odim) for factor in subsamples])

    if self.chunk_size_left > 0:
        assert self.chunk_size_left % self._factor == 0
    if self.chunk_size_current > 0:
        assert self.chunk_size_current % self._factor == 0
    if self.chunk_size_right > 0:
        assert self.chunk_size_right % self._factor == 0

    self.pos_emb = None
    self.u = None
    self.v = None
    if self.memory_transformer:
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
        self.u = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
        self.v = nn.Parameter(torch.Tensor(n_heads, d_model // n_heads))
        # NOTE: u and v are global parameters
    elif pe_type == 'relative':
        self.pos_emb = XLPositionalEmbedding(d_model, dropout)
    else:
        self.pos_enc = PositionalEncoding(d_model, dropout_in, pe_type, param_init)

    self.layers = nn.ModuleList([copy.deepcopy(TransformerEncoderBlock(
        d_model, d_ff, n_heads,
        dropout, dropout_att, dropout_layer,
        layer_norm_eps, ffn_activation, param_init,
        relative_attention=self.pos_emb is not None,
        ffn_bottleneck_dim=ffn_bottleneck_dim))
        for _ in range(n_layers)])
    self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)
    self._odim = d_model

    if n_layers_sub1 > 0:
        if task_specific_layer:
            self.layer_sub1 = TransformerEncoderBlock(
                d_model, d_ff, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)
        self.norm_out_sub1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub1 = nn.Linear(self._odim, last_proj_dim)

    if n_layers_sub2 > 0:
        if task_specific_layer:
            self.layer_sub2 = TransformerEncoderBlock(
                d_model, d_ff, n_heads,
                dropout, dropout_att, dropout_layer,
                layer_norm_eps, ffn_activation, param_init,
                ffn_bottleneck_dim=ffn_bottleneck_dim)
        self.norm_out_sub2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
        if last_proj_dim > 0 and last_proj_dim != self.output_dim:
            self.bridge_sub2 = nn.Linear(self._odim, last_proj_dim)

    if last_proj_dim > 0 and last_proj_dim != self.output_dim:
        self.bridge = nn.Linear(self._odim, last_proj_dim)
        self._odim = last_proj_dim

    self.reset_parameters(param_init)
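A minimal sketch of the `subsample` string parsing used by both encoder constructors above: an underscore-separated per-layer factor string is expanded to a list of length n_layers, and the total frame-rate reduction is the product of the factors. The example string and layer count are illustrative.

import numpy as np

def parse_subsample(subsample, n_layers):
    # mirrors the "parse subsample" block in the constructors above
    subsamples = [1] * n_layers
    for lth, s in enumerate(list(map(int, subsample.split('_')[:n_layers]))):
        subsamples[lth] = s
    return subsamples

factors = parse_subsample('1_2_2_1', n_layers=4)
print(factors)                 # [1, 2, 2, 1]
print(int(np.prod(factors)))   # 4: overall subsampling factor contributed by these layers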