def build_encoder(args):
    # safeguard: fall back to the shared Transformer hyperparameters when the
    # encoder/decoder-specific ones are not set (older configs)
    if not hasattr(args, 'transformer_enc_d_model') and hasattr(args, 'transformer_d_model'):
        args.transformer_enc_d_model = args.transformer_d_model
        args.transformer_dec_d_model = args.transformer_d_model
    if not hasattr(args, 'transformer_enc_d_ff') and hasattr(args, 'transformer_d_ff'):
        args.transformer_enc_d_ff = args.transformer_d_ff
    if not hasattr(args, 'transformer_enc_n_heads') and hasattr(args, 'transformer_n_heads'):
        args.transformer_enc_n_heads = args.transformer_n_heads

    if args.enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder
        encoder = TDSEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else args.dec_n_units)

    elif args.enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder
        raise ValueError  # NOTE: the GatedConv encoder is currently disabled
        encoder = GatedConvEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else args.dec_n_units,
            param_init=args.param_init)

    elif 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            n_heads=args.transformer_enc_n_heads,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_enc_d_model,
            d_ff=args.transformer_enc_d_ff,
            ffn_bottleneck_dim=args.transformer_ffn_bottleneck_dim,
            ffn_activation=args.transformer_ffn_activation,
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            dropout_layer=args.dropout_enc_layer,
            subsample=args.subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            clamp_len=args.transformer_enc_clamp_len,
            lookahead=args.transformer_enc_lookaheads,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right,
            streaming_type=args.lc_type)

    elif 'conformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.conformer import ConformerEncoder
        encoder = ConformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            n_heads=args.transformer_enc_n_heads,
            kernel_size=args.conformer_kernel_size,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_enc_d_model,
            d_ff=args.transformer_enc_d_ff,
            ffn_bottleneck_dim=args.transformer_ffn_bottleneck_dim,
            ffn_activation='swish',
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            dropout_layer=args.dropout_enc_layer,
            subsample=args.subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            clamp_len=args.transformer_enc_clamp_len,
            lookahead=args.transformer_enc_lookaheads,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right,
            streaming_type=args.lc_type)

    else:
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=args.subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            bidir_sum_fwd_bwd=args.bidirectional_sum_fwd_bwd,
            task_specific_layer=args.task_specific_layer,
            param_init=args.param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_right=args.lc_chunk_size_right,
            rsp_prob=args.rsp_prob_enc)

    return encoder
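
# The hasattr guards at the top of build_encoder() above exist for backward
# compatibility: older configs only define the shared `transformer_d_model`,
# `transformer_d_ff`, and `transformer_n_heads`, and the guards copy those into
# the newer encoder/decoder-specific attributes before the encoder is built.
# Below is a minimal, self-contained sketch of that backfill on a hypothetical
# argparse.Namespace (the attribute values are made up for illustration only).
from argparse import Namespace

legacy_args = Namespace(transformer_d_model=256,
                        transformer_d_ff=2048,
                        transformer_n_heads=4)

# same backfill logic as the safeguard in build_encoder()
if not hasattr(legacy_args, 'transformer_enc_d_model') and hasattr(legacy_args, 'transformer_d_model'):
    legacy_args.transformer_enc_d_model = legacy_args.transformer_d_model
    legacy_args.transformer_dec_d_model = legacy_args.transformer_d_model

print(legacy_args.transformer_enc_d_model)  # 256
print(legacy_args.transformer_dec_d_model)  # 256
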
def select_encoder(args):
    if 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            attn_type=args.transformer_attn_type,
            attn_n_heads=args.transformer_attn_n_heads,
            n_layers=args.enc_n_layers,
            d_model=args.d_model,
            d_ff=args.d_ff,
            pe_type=args.pe_type,
            layer_norm_eps=args.layer_norm_eps,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            last_proj_dim=args.d_model if 'transformer' in args.dec_type else args.dec_n_units,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            param_init=args.param_init)
    else:
        # expand the underscore-delimited subsampling spec into a per-layer list
        subsample = [1] * args.enc_n_layers
        for l, s in enumerate(list(map(int, args.subsample.split('_')[:args.enc_n_layers]))):
            subsample[l] = s

        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=subsample,
            subsample_type=args.subsample_type,
            last_proj_dim=args.d_model if 'transformer' in args.dec_type else args.dec_n_units,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            residual=args.enc_residual,
            nin=args.enc_nin,
            task_specific_layer=args.task_specific_layer,
            param_init=args.param_init)
        # NOTE: pure Conv/TDS/GatedConv encoders are also included

    return encoder
def build_encoder(args):
    if args.enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder
        raise ValueError  # NOTE: the TDS encoder is currently disabled
        encoder = TDSEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            bottleneck_dim=args.transformer_d_model if 'transformer' in args.dec_type else args.dec_n_units)

    elif args.enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder
        raise ValueError  # NOTE: the GatedConv encoder is currently disabled
        encoder = GatedConvEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel,
            channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes,
            dropout=args.dropout_enc,
            bottleneck_dim=args.transformer_d_model if 'transformer' in args.dec_type else args.dec_n_units,
            param_init=args.param_init)

    elif 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type,
            attn_type=args.transformer_attn_type,
            n_heads=args.transformer_n_heads,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_d_model,
            d_ff=args.transformer_d_ff,
            last_proj_dim=args.transformer_d_model if 'transformer' in args.dec_type else 0,
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            ffn_activation=args.transformer_ffn_activation,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            dropout_layer=args.dropout_enc_layer,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right)

    else:
        # expand the underscore-delimited subsampling spec into a per-layer list
        subsample = [1] * args.enc_n_layers
        for l, s in enumerate(list(map(int, args.subsample.split('_')[:args.enc_n_layers]))):
            subsample[l] = s

        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            last_proj_dim=args.transformer_d_model if 'transformer' in args.dec_type else 0,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=subsample,
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            bidirectional_sum_fwd_bwd=args.bidirectional_sum_fwd_bwd,
            task_specific_layer=args.task_specific_layer,
            param_init=args.param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_right=args.lc_chunk_size_right)
        # NOTE: pure Conv/TDS/GatedConv encoders are also included

    return encoder
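
# Both select_encoder() and the older build_encoder() above expand the
# underscore-delimited `--subsample` string into a per-layer factor list,
# padding with 1 when the string is shorter than the number of encoder layers.
# Below is a small, self-contained sketch of that expansion; `expand_subsample`
# is a hypothetical helper name used only for illustration.
def expand_subsample(subsample_str, enc_n_layers):
    """Expand e.g. '1_2_2_1' into a per-layer subsampling factor list."""
    factors = [1] * enc_n_layers
    for l, s in enumerate(list(map(int, subsample_str.split('_')[:enc_n_layers]))):
        factors[l] = s
    return factors

print(expand_subsample('1_2_2_1', 5))  # [1, 2, 2, 1, 1]
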
def __init__(self, args):
    super(ModelBase, self).__init__()

    # for encoder
    self.input_type = args.input_type
    self.input_dim = args.input_dim
    self.n_stacks = args.n_stacks
    self.n_skips = args.n_skips
    self.n_splices = args.n_splices
    self.enc_type = args.enc_type
    self.enc_n_units = args.enc_n_units
    if args.enc_type in ['blstm', 'bgru']:
        self.enc_n_units *= 2
    self.bridge_layer = args.bridge_layer

    # for OOV resolution
    self.enc_n_layers = args.enc_n_layers
    self.enc_n_layers_sub1 = args.enc_n_layers_sub1
    self.subsample = [int(s) for s in args.subsample.split('_')]

    # for attention layer
    self.attn_n_heads = args.attn_n_heads

    # for decoder
    self.vocab = args.vocab
    self.vocab_sub1 = args.vocab_sub1
    self.vocab_sub2 = args.vocab_sub2
    self.blank = 0
    self.unk = 1
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    # for the sub tasks
    self.main_weight = 1 - args.sub1_weight - args.sub2_weight
    self.sub1_weight = args.sub1_weight
    self.sub2_weight = args.sub2_weight
    self.mtl_per_batch = args.mtl_per_batch
    self.task_specific_layer = args.task_specific_layer

    # for CTC
    self.ctc_weight = min(args.ctc_weight, self.main_weight)
    self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
    self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

    # for backward decoder
    self.bwd_weight = min(args.bwd_weight, self.main_weight)
    self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
    self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
    self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

    # Feature extraction
    self.ssn = None
    if args.sequence_summary_network:
        assert args.input_type == 'speech'
        self.ssn = SequenceSummaryNetwork(args.input_dim,
                                          n_units=512,
                                          n_layers=3,
                                          bottleneck_dim=100,
                                          dropout=0)

    # Encoder
    if args.enc_type == 'transformer':
        self.enc = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            attn_type=args.transformer_attn_type,
            attn_n_heads=args.transformer_attn_n_heads,
            n_layers=args.transformer_enc_n_layers,
            d_model=args.d_model,
            d_ff=args.d_ff,
            # pe_type=args.pe_type,
            pe_type=False,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            dropout_att=args.dropout_att,
            layer_norm_eps=args.layer_norm_eps,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual,
            conv_bottleneck_dim=args.conv_bottleneck_dim)
    else:
        self.enc = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            n_units=args.enc_n_units,
            n_projs=args.enc_n_projs,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1,
            n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in,
            dropout=args.dropout_enc,
            subsample=list(map(int, args.subsample.split('_'))) + [1] * (args.enc_n_layers - len(args.subsample.split('_'))),
            subsample_type=args.subsample_type,
            n_stacks=args.n_stacks,
            n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel,
            conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes,
            conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings,
            conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual,
            conv_bottleneck_dim=args.conv_bottleneck_dim,
            residual=args.enc_residual,
            nin=args.enc_nin,
            task_specific_layer=args.task_specific_layer)
        # NOTE: pure CNN/TDS encoders are also included

    if args.freeze_encoder:
        for p in self.enc.parameters():
            p.requires_grad = False

    # Bridge layer between the encoder and decoder
    self.is_bridge = False
    if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer'] and args.ctc_weight < 1) \
            or args.dec_type == 'transformer' or args.bridge_layer:
        self.bridge = LinearND(self.enc.output_dim,
                               args.d_model if args.dec_type == 'transformer' else args.dec_n_units,
                               dropout=args.dropout_enc)
        self.is_bridge = True
        if self.sub1_weight > 0:
            self.bridge_sub1 = LinearND(self.enc.output_dim, args.dec_n_units,
                                        dropout=args.dropout_enc)
        if self.sub2_weight > 0:
            self.bridge_sub2 = LinearND(self.enc.output_dim, args.dec_n_units,
                                        dropout=args.dropout_enc)
        self.enc_n_units = args.dec_n_units

    # main task
    directions = []
    if self.fwd_weight > 0 or self.ctc_weight > 0:
        directions.append('fwd')
    if self.bwd_weight > 0:
        directions.append('bwd')

    for dir in directions:
        # Cold fusion
        if args.lm_fusion and dir == 'fwd':
            lm = RNNLM(args.lm_conf)
            lm, _ = load_checkpoint(lm, args.lm_fusion)
        else:
            args.lm_conf = False
            lm = None
        # TODO(hirofumi): cold fusion for backward RNNLM

        # Decoder
        if args.dec_type == 'transformer':
            dec = TransformerDecoder(
                eos=self.eos,
                unk=self.unk,
                pad=self.pad,
                blank=self.blank,
                enc_n_units=self.enc.output_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_dec_n_layers,
                d_model=args.d_model,
                d_ff=args.d_ff,
                pe_type=args.pe_type,
                tie_embedding=args.tie_embedding,
                vocab=self.vocab,
                dropout=args.dropout_dec,
                dropout_emb=args.dropout_emb,
                dropout_att=args.dropout_att,
                lsm_prob=args.lsm_prob,
                layer_norm_eps=args.layer_norm_eps,
                ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                ctc_fc_list=[int(fc) for fc in args.ctc_fc_list.split('_')]
                if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [],
                backward=(dir == 'bwd'),
                global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                mtl_per_batch=args.mtl_per_batch)
        else:
            dec = RNNDecoder(
                eos=self.eos,
                unk=self.unk,
                pad=self.pad,
                blank=self.blank,
                enc_n_units=self.enc.output_dim,
                attn_type=args.attn_type,
                attn_dim=args.attn_dim,
                attn_sharpening_factor=args.attn_sharpening,
                attn_sigmoid_smoothing=args.attn_sigmoid,
                attn_conv_out_channels=args.attn_conv_n_channels,
                attn_conv_kernel_size=args.attn_conv_width,
                attn_n_heads=args.attn_n_heads,
                rnn_type=args.dec_type,
                n_units=args.dec_n_units,
                n_projs=args.dec_n_projs,
                n_layers=args.dec_n_layers,
                loop_type=args.dec_loop_type,
                residual=args.dec_residual,
                bottleneck_dim=args.dec_bottleneck_dim,
                emb_dim=args.emb_dim,
                tie_embedding=args.tie_embedding,
                vocab=self.vocab,
                dropout=args.dropout_dec,
                dropout_emb=args.dropout_emb,
                dropout_att=args.dropout_att,
                ss_prob=args.ss_prob,
                ss_type=args.ss_type,
                lsm_prob=args.lsm_prob,
                fl_weight=args.focal_loss_weight,
                fl_gamma=args.focal_loss_gamma,
                ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                ctc_fc_list=[int(fc) for fc in args.ctc_fc_list.split('_')]
                if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [],
                input_feeding=args.input_feeding,
                backward=(dir == 'bwd'),
                # lm=args.lm_conf,
                lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                lm_fusion_type=args.lm_fusion_type,
                contextualize=args.contextualize,
                lm_init=args.lm_init,
                lmobj_weight=args.lmobj_weight,
                share_lm_softmax=args.share_lm_softmax,
                global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                mtl_per_batch=args.mtl_per_batch,
                adaptive_softmax=args.adaptive_softmax)
        setattr(self, 'dec_' + dir, dec)

    # sub task
    for sub in ['sub1', 'sub2']:
        if getattr(self, sub + '_weight') > 0:
            if args.dec_type == 'transformer':
                raise NotImplementedError
            else:
                dec_sub = RNNDecoder(
                    eos=self.eos,
                    unk=self.unk,
                    pad=self.pad,
                    blank=self.blank,
                    enc_n_units=self.enc_n_units,
                    attn_type=args.attn_type,
                    attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=1,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units,
                    n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=getattr(self, 'vocab_' + sub),
                    dropout=args.dropout_dec,
                    dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob,
                    ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight,
                    fl_gamma=args.focal_loss_gamma,
                    ctc_weight=getattr(self, 'ctc_weight_' + sub),
                    ctc_fc_list=[int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_')]
                    if getattr(args, 'ctc_fc_list_' + sub) is not None and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else [],
                    input_feeding=args.input_feeding,
                    global_weight=getattr(self, sub + '_weight'),
                    mtl_per_batch=args.mtl_per_batch)
            setattr(self, 'dec_fwd_' + sub, dec_sub)

    if args.input_type == 'text':
        if args.vocab == args.vocab_sub1:
            # Share the embedding layer between input and output
            self.embed_in = dec.embed
        else:
            self.embed_in = Embedding(vocab=args.vocab_sub1,
                                      emb_dim=args.emb_dim,
                                      dropout=args.dropout_emb,
                                      ignore_index=self.pad)

    # Initialize parameters in CNN layers
    self.reset_parameters(args.param_init,
                          # dist='xavier_uniform',
                          # dist='kaiming_uniform',
                          dist='lecun',
                          keys=['conv'],
                          ignore_keys=['score'])

    # Initialize parameters in the encoder
    if args.enc_type == 'transformer':
        self.reset_parameters(args.param_init, dist='xavier_uniform',
                              keys=['enc'], ignore_keys=['embed_in'])
        self.reset_parameters(args.d_model**-0.5, dist='normal',
                              keys=['embed_in'])
    else:
        self.reset_parameters(args.param_init, dist=args.param_init_dist,
                              keys=['enc'], ignore_keys=['conv'])

    # Initialize parameters in the decoder
    if args.dec_type == 'transformer':
        self.reset_parameters(args.param_init, dist='xavier_uniform',
                              keys=['dec'], ignore_keys=['embed'])
        self.reset_parameters(args.d_model**-0.5, dist='normal',
                              keys=['embed'])
    else:
        self.reset_parameters(args.param_init, dist=args.param_init_dist,
                              keys=['dec'])

    # Initialize bias vectors with zero
    self.reset_parameters(0, dist='constant', keys=['bias'])

    # Recurrent weights are orthogonalized
    if args.rec_weight_orthogonal:
        self.reset_parameters(args.param_init, dist='orthogonal',
                              keys=['rnn', 'weight'])

    # Initialize bias in forget gate with 1
    # self.init_forget_gate_bias_with_one()

    # Initialize bias in gating with -1 for cold fusion
    if args.lm_fusion:
        self.reset_parameters(-1, dist='constant',
                              keys=['linear_lm_gate.fc.bias'])

    if args.lm_fusion_type == 'deep' and args.lm_fusion:
        for n, p in self.named_parameters():
            if 'output' in n or 'output_bn' in n or 'linear' in n:
                p.requires_grad = True
            else:
                p.requires_grad = False
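
# The tail of __init__() above drives all weight initialization through
# self.reset_parameters(), picking an init distribution per parameter group by
# substring-matching parameter names (keys / ignore_keys). The sketch below is
# NOT the actual neural_sp implementation; it is a hypothetical re-implementation
# of that name-filtered pattern, assuming standard torch.nn.init primitives.
import torch.nn as nn


def reset_parameters_sketch(module, param_init, dist='uniform', keys=None, ignore_keys=None):
    """Initialize parameters whose names contain any of `keys` and none of `ignore_keys`."""
    keys = keys or []
    ignore_keys = ignore_keys or []
    for name, p in module.named_parameters():
        if keys and not any(k in name for k in keys):
            continue
        if any(k in name for k in ignore_keys):
            continue
        if dist == 'constant':
            nn.init.constant_(p, param_init)
        elif dist == 'normal':
            nn.init.normal_(p, mean=0.0, std=param_init)
        elif dist == 'uniform':
            nn.init.uniform_(p, a=-param_init, b=param_init)
        elif dist == 'xavier_uniform' and p.dim() > 1:
            nn.init.xavier_uniform_(p)
        elif dist == 'orthogonal' and p.dim() > 1:
            nn.init.orthogonal_(p, gain=param_init)
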