def build_encoder(args):
    if args.enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder
        encoder = TDSEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel, channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes, dropout=args.dropout_enc,
            bottleneck_dim=args.transformer_d_model if 'transformer' in args.dec_type else args.dec_n_units)
    elif args.enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder
        encoder = GatedConvEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel, channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes, dropout=args.dropout_enc,
            bottleneck_dim=args.transformer_d_model if 'transformer' in args.dec_type else args.dec_n_units,
            param_init=args.param_init)
    elif 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type, attn_type=args.transformer_attn_type,
            n_heads=args.transformer_n_heads, n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_d_model, d_ff=args.transformer_d_ff,
            last_proj_dim=args.transformer_d_model if 'transformer' in args.dec_type else 0,
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            ffn_activation=args.transformer_ffn_activation,
            dropout_in=args.dropout_in, dropout=args.dropout_enc,
            dropout_att=args.dropout_att, dropout_layer=args.dropout_enc_layer,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm, conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right)
    else:
        subsample = [1] * args.enc_n_layers
        for l, s in enumerate(list(map(int, args.subsample.split('_')[:args.enc_n_layers]))):
            subsample[l] = s
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type, n_units=args.enc_n_units, n_projs=args.enc_n_projs,
            last_proj_dim=args.transformer_d_model if 'transformer' in args.dec_type else 0,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in, dropout=args.dropout_enc,
            subsample=subsample, subsample_type=args.subsample_type,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm, conv_bottleneck_dim=args.conv_bottleneck_dim,
            bidirectional_sum_fwd_bwd=args.bidirectional_sum_fwd_bwd,
            task_specific_layer=args.task_specific_layer, param_init=args.param_init,
            chunk_size_left=args.lc_chunk_size_left, chunk_size_right=args.lc_chunk_size_right)
        # NOTE: pure Conv/TDS/GatedConv encoders are also included
    return encoder
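# The RNN branch above pads the underscore-delimited `args.subsample` string
# out to `enc_n_layers` entries. A minimal standalone sketch of that parsing;
# the helper name and the example string are illustrative, not part of neural_sp.
def parse_subsample(subsample_str, n_layers):
    factors = [1] * n_layers  # layers beyond the string keep factor 1
    for i, s in enumerate(list(map(int, subsample_str.split('_')[:n_layers]))):
        factors[i] = s
    return factors

assert parse_subsample('1_2_2', 5) == [1, 2, 2, 1, 1]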
def build_encoder(args):
    # safeguard: old configs only carry the shared transformer_* names, so copy
    # them onto the newer per-module (encoder/decoder) attributes
    if not hasattr(args, 'transformer_enc_d_model') and hasattr(args, 'transformer_d_model'):
        args.transformer_enc_d_model = args.transformer_d_model
        args.transformer_dec_d_model = args.transformer_d_model
    if not hasattr(args, 'transformer_enc_d_ff') and hasattr(args, 'transformer_d_ff'):
        args.transformer_enc_d_ff = args.transformer_d_ff
    if not hasattr(args, 'transformer_enc_n_heads') and hasattr(args, 'transformer_n_heads'):
        args.transformer_enc_n_heads = args.transformer_n_heads

    if args.enc_type == 'tds':
        from neural_sp.models.seq2seq.encoders.tds import TDSEncoder
        encoder = TDSEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel, channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes, dropout=args.dropout_enc,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else args.dec_n_units)
    elif args.enc_type == 'gated_conv':
        from neural_sp.models.seq2seq.encoders.gated_conv import GatedConvEncoder
        encoder = GatedConvEncoder(
            input_dim=args.input_dim * args.n_stacks,
            in_channel=args.conv_in_channel, channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes, dropout=args.dropout_enc,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else args.dec_n_units,
            param_init=args.param_init)
    elif 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type, n_heads=args.transformer_enc_n_heads,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_enc_d_model, d_ff=args.transformer_enc_d_ff,
            ffn_bottleneck_dim=args.transformer_ffn_bottleneck_dim,
            ffn_activation=args.transformer_ffn_activation,
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            dropout_in=args.dropout_in, dropout=args.dropout_enc,
            dropout_att=args.dropout_att, dropout_layer=args.dropout_enc_layer,
            subsample=args.subsample, subsample_type=args.subsample_type,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm, conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            clamp_len=args.transformer_enc_clamp_len,
            lookahead=args.transformer_enc_lookaheads,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right,
            streaming_type=args.lc_type)
    elif 'conformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.conformer import ConformerEncoder
        encoder = ConformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type, n_heads=args.transformer_enc_n_heads,
            kernel_size=args.conformer_kernel_size,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
            d_model=args.transformer_enc_d_model, d_ff=args.transformer_enc_d_ff,
            ffn_bottleneck_dim=args.transformer_ffn_bottleneck_dim,
            ffn_activation='swish',
            pe_type=args.transformer_enc_pe_type,
            layer_norm_eps=args.transformer_layer_norm_eps,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            dropout_in=args.dropout_in, dropout=args.dropout_enc,
            dropout_att=args.dropout_att, dropout_layer=args.dropout_enc_layer,
            subsample=args.subsample, subsample_type=args.subsample_type,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm, conv_bottleneck_dim=args.conv_bottleneck_dim,
            conv_param_init=args.param_init,
            task_specific_layer=args.task_specific_layer,
            param_init=args.transformer_param_init,
            clamp_len=args.transformer_enc_clamp_len,
            lookahead=args.transformer_enc_lookaheads,
            chunk_size_left=args.lc_chunk_size_left,
            chunk_size_current=args.lc_chunk_size_current,
            chunk_size_right=args.lc_chunk_size_right,
            streaming_type=args.lc_type)
    else:
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            enc_type=args.enc_type, n_units=args.enc_n_units, n_projs=args.enc_n_projs,
            last_proj_dim=args.transformer_dec_d_model if 'transformer' in args.dec_type else 0,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in, dropout=args.dropout_enc,
            subsample=args.subsample, subsample_type=args.subsample_type,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_layer_norm=args.conv_layer_norm, conv_bottleneck_dim=args.conv_bottleneck_dim,
            bidir_sum_fwd_bwd=args.bidirectional_sum_fwd_bwd,
            task_specific_layer=args.task_specific_layer, param_init=args.param_init,
            chunk_size_left=args.lc_chunk_size_left, chunk_size_right=args.lc_chunk_size_right,
            rsp_prob=args.rsp_prob_enc)
    return encoder
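# The "safeguard" block at the top of this build_encoder maps old shared
# hyperparameter names onto the newer per-module ones, so configs saved before
# the encoder/decoder split still load. A small self-contained sketch of the
# same shim, assuming an argparse.Namespace-style config; the values are made up.
from argparse import Namespace

old_args = Namespace(transformer_d_model=256, transformer_d_ff=2048, transformer_n_heads=4)
if not hasattr(old_args, 'transformer_enc_d_model') and hasattr(old_args, 'transformer_d_model'):
    old_args.transformer_enc_d_model = old_args.transformer_d_model
    old_args.transformer_dec_d_model = old_args.transformer_d_model
if not hasattr(old_args, 'transformer_enc_d_ff') and hasattr(old_args, 'transformer_d_ff'):
    old_args.transformer_enc_d_ff = old_args.transformer_d_ff
if not hasattr(old_args, 'transformer_enc_n_heads') and hasattr(old_args, 'transformer_n_heads'):
    old_args.transformer_enc_n_heads = old_args.transformer_n_heads
assert old_args.transformer_enc_d_model == 256 and old_args.transformer_dec_d_model == 256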
def select_encoder(args):
    if 'transformer' in args.enc_type:
        from neural_sp.models.seq2seq.encoders.transformer import TransformerEncoder
        encoder = TransformerEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            attn_type=args.transformer_attn_type, attn_n_heads=args.transformer_attn_n_heads,
            n_layers=args.enc_n_layers,
            d_model=args.d_model, d_ff=args.d_ff,
            pe_type=args.pe_type, layer_norm_eps=args.layer_norm_eps,
            dropout_in=args.dropout_in, dropout=args.dropout_enc, dropout_att=args.dropout_att,
            last_proj_dim=args.d_model if 'transformer' in args.dec_type else args.dec_n_units,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual, conv_bottleneck_dim=args.conv_bottleneck_dim,
            param_init=args.param_init)
    else:
        subsample = [1] * args.enc_n_layers
        for l, s in enumerate(list(map(int, args.subsample.split('_')[:args.enc_n_layers]))):
            subsample[l] = s
        from neural_sp.models.seq2seq.encoders.rnn import RNNEncoder
        encoder = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type, n_units=args.enc_n_units, n_projs=args.enc_n_projs,
            n_layers=args.enc_n_layers,
            n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
            dropout_in=args.dropout_in, dropout=args.dropout_enc,
            subsample=subsample, subsample_type=args.subsample_type,
            last_proj_dim=args.d_model if 'transformer' in args.dec_type else args.dec_n_units,
            n_stacks=args.n_stacks, n_splices=args.n_splices,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            conv_residual=args.conv_residual, conv_bottleneck_dim=args.conv_bottleneck_dim,
            residual=args.enc_residual, nin=args.enc_nin,
            task_specific_layer=args.task_specific_layer, param_init=args.param_init)
    # NOTE: pure Conv/TDS/GatedConv encoders are also included
    return encoder
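# All builder versions size the encoder's final projection from the decoder
# type: transformer decoders expect d_model-sized states, RNN decoders expect
# dec_n_units-sized states. The rule as a tiny standalone helper; the function
# name is illustrative, not part of neural_sp.
def bridge_dim(dec_type, d_model, dec_n_units):
    return d_model if 'transformer' in dec_type else dec_n_units

assert bridge_dim('transformer', 256, 1024) == 256
assert bridge_dim('lstm', 256, 1024) == 1024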
class Seq2seq(ModelBase):
    """Attention-based RNN sequence-to-sequence model (including CTC)."""

    def __init__(self, args):
        super(ModelBase, self).__init__()

        # for encoder
        self.input_type = args.input_type
        self.input_dim = args.input_dim
        self.n_stacks = args.n_stacks
        self.n_skips = args.n_skips
        self.n_splices = args.n_splices
        self.enc_type = args.enc_type
        self.enc_n_units = args.enc_n_units
        if args.enc_type in ['blstm', 'bgru']:
            self.enc_n_units *= 2
        self.bridge_layer = args.bridge_layer

        # for OOV resolution
        self.enc_n_layers = args.enc_n_layers
        self.enc_n_layers_sub1 = args.enc_n_layers_sub1
        self.subsample = [int(s) for s in args.subsample.split('_')]

        # for attention layer
        self.attn_n_heads = args.attn_n_heads

        # for decoder
        self.vocab = args.vocab
        self.vocab_sub1 = args.vocab_sub1
        self.vocab_sub2 = args.vocab_sub2
        self.blank = 0
        self.unk = 1
        self.eos = 2
        self.pad = 3
        # NOTE: reserved in advance

        # for the sub tasks
        self.main_weight = 1 - args.sub1_weight - args.sub2_weight
        self.sub1_weight = args.sub1_weight
        self.sub2_weight = args.sub2_weight
        self.mtl_per_batch = args.mtl_per_batch
        self.task_specific_layer = args.task_specific_layer

        # for CTC
        self.ctc_weight = min(args.ctc_weight, self.main_weight)
        self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
        self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

        # for backward decoder
        self.bwd_weight = min(args.bwd_weight, self.main_weight)
        self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
        self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
        self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

        # Feature extraction
        self.ssn = None
        if args.sequence_summary_network:
            assert args.input_type == 'speech'
            self.ssn = SequenceSummaryNetwork(args.input_dim, n_units=512, n_layers=3,
                                              bottleneck_dim=100, dropout=0)

        # Encoder
        if args.enc_type == 'transformer':
            self.enc = TransformerEncoder(
                input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                n_layers=args.transformer_enc_n_layers,
                d_model=args.d_model, d_ff=args.d_ff,
                # pe_type=args.pe_type,
                pe_type=False,
                dropout_in=args.dropout_in, dropout=args.dropout_enc,
                dropout_att=args.dropout_att,
                layer_norm_eps=args.layer_norm_eps,
                n_stacks=args.n_stacks, n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual, conv_bottleneck_dim=args.conv_bottleneck_dim)
        else:
            self.enc = RNNEncoder(
                input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
                rnn_type=args.enc_type,
                n_units=args.enc_n_units, n_projs=args.enc_n_projs,
                n_layers=args.enc_n_layers,
                n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2,
                dropout_in=args.dropout_in, dropout=args.dropout_enc,
                subsample=list(map(int, args.subsample.split('_'))) +
                    [1] * (args.enc_n_layers - len(args.subsample.split('_'))),
                subsample_type=args.subsample_type,
                n_stacks=args.n_stacks, n_splices=args.n_splices,
                conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
                conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
                conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
                conv_residual=args.conv_residual, conv_bottleneck_dim=args.conv_bottleneck_dim,
                residual=args.enc_residual, nin=args.enc_nin,
                task_specific_layer=args.task_specific_layer)
            # NOTE: pure CNN/TDS encoders are also included

        if args.freeze_encoder:
            for p in self.enc.parameters():
                p.requires_grad = False

        # Bridge layer between the encoder and decoder
        self.is_bridge = False
        if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer'] and args.ctc_weight < 1) \
                or args.dec_type == 'transformer' or args.bridge_layer:
            self.bridge = LinearND(self.enc.output_dim,
                                   args.d_model if args.dec_type == 'transformer' else args.dec_n_units,
                                   dropout=args.dropout_enc)
            self.is_bridge = True
            if self.sub1_weight > 0:
                self.bridge_sub1 = LinearND(self.enc.output_dim, args.dec_n_units,
                                            dropout=args.dropout_enc)
            if self.sub2_weight > 0:
                self.bridge_sub2 = LinearND(self.enc.output_dim, args.dec_n_units,
                                            dropout=args.dropout_enc)
            self.enc_n_units = args.dec_n_units

        # main task
        directions = []
        if self.fwd_weight > 0 or self.ctc_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            # Cold fusion
            if args.lm_fusion and dir == 'fwd':
                lm = RNNLM(args.lm_conf)
                lm, _ = load_checkpoint(lm, args.lm_fusion)
            else:
                args.lm_conf = False
                lm = None
            # TODO(hirofumi): cold fusion for backward RNNLM

            # Decoder
            if args.dec_type == 'transformer':
                dec = TransformerDecoder(
                    eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.transformer_attn_type,
                    attn_n_heads=args.transformer_attn_n_heads,
                    n_layers=args.transformer_dec_n_layers,
                    d_model=args.d_model, d_ff=args.d_ff,
                    pe_type=args.pe_type,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec, dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    lsm_prob=args.lsm_prob,
                    layer_norm_eps=args.layer_norm_eps,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[int(fc) for fc in args.ctc_fc_list.split('_')]
                        if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [],
                    backward=(dir == 'bwd'),
                    global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch)
            else:
                dec = RNNDecoder(
                    eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank,
                    enc_n_units=self.enc.output_dim,
                    attn_type=args.attn_type, attn_dim=args.attn_dim,
                    attn_sharpening_factor=args.attn_sharpening,
                    attn_sigmoid_smoothing=args.attn_sigmoid,
                    attn_conv_out_channels=args.attn_conv_n_channels,
                    attn_conv_kernel_size=args.attn_conv_width,
                    attn_n_heads=args.attn_n_heads,
                    rnn_type=args.dec_type,
                    n_units=args.dec_n_units, n_projs=args.dec_n_projs,
                    n_layers=args.dec_n_layers,
                    loop_type=args.dec_loop_type,
                    residual=args.dec_residual,
                    bottleneck_dim=args.dec_bottleneck_dim,
                    emb_dim=args.emb_dim,
                    tie_embedding=args.tie_embedding,
                    vocab=self.vocab,
                    dropout=args.dropout_dec, dropout_emb=args.dropout_emb,
                    dropout_att=args.dropout_att,
                    ss_prob=args.ss_prob, ss_type=args.ss_type,
                    lsm_prob=args.lsm_prob,
                    fl_weight=args.focal_loss_weight, fl_gamma=args.focal_loss_gamma,
                    ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
                    ctc_fc_list=[int(fc) for fc in args.ctc_fc_list.split('_')]
                        if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [],
                    input_feeding=args.input_feeding,
                    backward=(dir == 'bwd'),
                    # lm=args.lm_conf,
                    lm=lm,  # TODO(hirofumi): load RNNLM in the model init.
                    lm_fusion_type=args.lm_fusion_type,
                    contextualize=args.contextualize,
                    lm_init=args.lm_init,
                    lmobj_weight=args.lmobj_weight,
                    share_lm_softmax=args.share_lm_softmax,
                    global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
                    mtl_per_batch=args.mtl_per_batch,
                    adaptive_softmax=args.adaptive_softmax)
            setattr(self, 'dec_' + dir, dec)

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                if args.dec_type == 'transformer':
                    raise NotImplementedError
                else:
                    dec_sub = RNNDecoder(
                        eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank,
                        enc_n_units=self.enc_n_units,
                        attn_type=args.attn_type, attn_dim=args.attn_dim,
                        attn_sharpening_factor=args.attn_sharpening,
                        attn_sigmoid_smoothing=args.attn_sigmoid,
                        attn_conv_out_channels=args.attn_conv_n_channels,
                        attn_conv_kernel_size=args.attn_conv_width,
                        attn_n_heads=1,
                        rnn_type=args.dec_type,
                        n_units=args.dec_n_units, n_projs=args.dec_n_projs,
                        n_layers=args.dec_n_layers,
                        loop_type=args.dec_loop_type,
                        residual=args.dec_residual,
                        bottleneck_dim=args.dec_bottleneck_dim,
                        emb_dim=args.emb_dim,
                        tie_embedding=args.tie_embedding,
                        vocab=getattr(self, 'vocab_' + sub),
                        dropout=args.dropout_dec, dropout_emb=args.dropout_emb,
                        dropout_att=args.dropout_att,
                        ss_prob=args.ss_prob, ss_type=args.ss_type,
                        lsm_prob=args.lsm_prob,
                        fl_weight=args.focal_loss_weight, fl_gamma=args.focal_loss_gamma,
                        ctc_weight=getattr(self, 'ctc_weight_' + sub),
                        ctc_fc_list=[int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_')]
                            if getattr(args, 'ctc_fc_list_' + sub) is not None
                            and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else [],
                        input_feeding=args.input_feeding,
                        global_weight=getattr(self, sub + '_weight'),
                        mtl_per_batch=args.mtl_per_batch)
                setattr(self, 'dec_fwd_' + sub, dec_sub)

        if args.input_type == 'text':
            if args.vocab == args.vocab_sub1:
                # Share the embedding layer between input and output
                self.embed_in = dec.embed
            else:
                self.embed_in = Embedding(vocab=args.vocab_sub1, emb_dim=args.emb_dim,
                                          dropout=args.dropout_emb, ignore_index=self.pad)

        # Initialize parameters in CNN layers
        self.reset_parameters(args.param_init,
                              # dist='xavier_uniform',
                              # dist='kaiming_uniform',
                              dist='lecun',
                              keys=['conv'], ignore_keys=['score'])

        # Initialize parameters in the encoder
        if args.enc_type == 'transformer':
            self.reset_parameters(args.param_init, dist='xavier_uniform',
                                  keys=['enc'], ignore_keys=['embed_in'])
            self.reset_parameters(args.d_model**-0.5, dist='normal', keys=['embed_in'])
        else:
            self.reset_parameters(args.param_init, dist=args.param_init_dist,
                                  keys=['enc'], ignore_keys=['conv'])

        # Initialize parameters in the decoder
        if args.dec_type == 'transformer':
            self.reset_parameters(args.param_init, dist='xavier_uniform',
                                  keys=['dec'], ignore_keys=['embed'])
            self.reset_parameters(args.d_model**-0.5, dist='normal', keys=['embed'])
        else:
            self.reset_parameters(args.param_init, dist=args.param_init_dist, keys=['dec'])

        # Initialize bias vectors with zero
        self.reset_parameters(0, dist='constant', keys=['bias'])

        # Recurrent weights are orthogonalized
        if args.rec_weight_orthogonal:
            self.reset_parameters(args.param_init, dist='orthogonal', keys=['rnn', 'weight'])

        # Initialize bias in forget gate with 1
        # self.init_forget_gate_bias_with_one()

        # Initialize bias in gating with -1 for cold fusion
        if args.lm_fusion:
            self.reset_parameters(-1, dist='constant', keys=['linear_lm_gate.fc.bias'])

        if args.lm_fusion_type == 'deep' and args.lm_fusion:
            for n, p in self.named_parameters():
                if 'output' in n or 'output_bn' in n or 'linear' in n:
                    p.requires_grad = True
                else:
                    p.requires_grad = False

    def scheduled_sampling_trigger(self):
        # main task
        directions = []
        if self.fwd_weight > 0:
            directions.append('fwd')
        if self.bwd_weight > 0:
            directions.append('bwd')
        for dir in directions:
            getattr(self, 'dec_' + dir).start_scheduled_sampling()

        # sub task
        for sub in ['sub1', 'sub2']:
            if getattr(self, sub + '_weight') > 0:
                directions = []
                if getattr(self, 'fwd_weight_' + sub) > 0:
                    directions.append('fwd')
                for dir_sub in directions:
                    getattr(self, 'dec_' + dir_sub + '_' + sub).start_scheduled_sampling()

    def forward(self, batch, reporter=None, task='all', is_eval=False):
        """Forward computation.

        Args:
            batch (dict):
                xs (list): input data of size `[T, input_dim]`
                xlens (list): lengths of each element in xs
                ys (list): reference labels in the main task of size `[L]`
                ys_sub1 (list): reference labels in the 1st auxiliary task of size `[L_sub1]`
                ys_sub2 (list): reference labels in the 2nd auxiliary task of size `[L_sub2]`
                utt_ids (list): name of utterances
                speakers (list): name of speakers
            reporter ():
            task (str): all or ys* or ys_sub*
            is_eval (bool): the history will not be saved.
                This should be used in inference model for memory efficiency.
        Returns:
            loss (FloatTensor): `[1]`
            reporter ():

        """
        if is_eval:
            self.eval()
            with torch.no_grad():
                loss, reporter = self._forward(batch, task, reporter)
        else:
            self.train()
            loss, reporter = self._forward(batch, task, reporter)
        return loss, reporter

    def _forward(self, batch, task, reporter):
        # Encode input features
        if self.input_type == 'speech':
            if self.mtl_per_batch:
                flip = True if 'bwd' in task else False
                enc_outs = self.encode(batch['xs'], task, flip=flip)
            else:
                flip = True if self.bwd_weight == 1 else False
                enc_outs = self.encode(batch['xs'], 'all', flip=flip)
        else:
            enc_outs = self.encode(batch['ys_sub1'])

        observation = {}
        loss = torch.zeros((1,), dtype=torch.float32).cuda(self.device_id)

        # for the forward decoder in the main task
        if (self.fwd_weight > 0 or self.ctc_weight > 0) and task in ['all', 'ys', 'ys.ctc', 'ys.lmobj']:
            loss_fwd, obs_fwd = self.dec_fwd(enc_outs['ys']['xs'], enc_outs['ys']['xlens'],
                                             batch['ys'], task, batch['ys_hist'])
            loss += loss_fwd
            observation['loss.att'] = obs_fwd['loss_att']
            observation['loss.ctc'] = obs_fwd['loss_ctc']
            observation['loss.lmobj'] = obs_fwd['loss_lmobj']
            observation['acc.att'] = obs_fwd['acc_att']
            observation['acc.lmobj'] = obs_fwd['acc_lmobj']
            observation['ppl.att'] = obs_fwd['ppl_att']
            observation['ppl.lmobj'] = obs_fwd['ppl_lmobj']

        # for the backward decoder in the main task
        if self.bwd_weight > 0 and task in ['all', 'ys.bwd']:
            loss_bwd, obs_bwd = self.dec_bwd(enc_outs['ys']['xs'], enc_outs['ys']['xlens'],
                                             batch['ys'], task)
            loss += loss_bwd
            observation['loss.att-bwd'] = obs_bwd['loss_att']
            observation['loss.ctc-bwd'] = obs_bwd['loss_ctc']
            observation['loss.lmobj-bwd'] = obs_bwd['loss_lmobj']
            observation['acc.att-bwd'] = obs_bwd['acc_att']
            observation['acc.lmobj-bwd'] = obs_bwd['acc_lmobj']
            observation['ppl.att-bwd'] = obs_bwd['ppl_att']
            observation['ppl.lmobj-bwd'] = obs_bwd['ppl_lmobj']

        # only fwd for sub tasks
        for sub in ['sub1', 'sub2']:
            # for the forward decoder in the sub tasks
            if (getattr(self, 'fwd_weight_' + sub) > 0 or getattr(self, 'ctc_weight_' + sub) > 0) \
                    and task in ['all', 'ys_' + sub, 'ys_' + sub + '.ctc', 'ys_' + sub + '.lmobj']:
                loss_sub, obs_fwd_sub = getattr(self, 'dec_fwd_' + sub)(
                    enc_outs['ys_' + sub]['xs'], enc_outs['ys_' + sub]['xlens'],
                    batch['ys_' + sub], task)
                loss += loss_sub
                observation['loss.att-' + sub] = obs_fwd_sub['loss_att']
                observation['loss.ctc-' + sub] = obs_fwd_sub['loss_ctc']
                observation['loss.lmobj-' + sub] = obs_fwd_sub['loss_lmobj']
                observation['acc.att-' + sub] = obs_fwd_sub['acc_att']
                observation['acc.lmobj-' + sub] = obs_fwd_sub['acc_lmobj']
                observation['ppl.att-' + sub] = obs_fwd_sub['ppl_att']
                observation['ppl.lmobj-' + sub] = obs_fwd_sub['ppl_lmobj']

        if reporter is not None:
            is_eval = not self.training
            reporter.add(observation, is_eval)
        return loss, reporter

    def encode(self, xs, task='all', flip=False):
        """Encode acoustic or text features.

        Args:
            xs (list): A list of length `[B]`, which contains Tensor of size `[T, input_dim]`
            task (str): all or ys* or ys_sub1* or ys_sub2*
            flip (bool): if True, flip acoustic features in the time-dimension
        Returns:
            enc_outs (dict):

        """
        if 'lmobj' in task:
            eouts = {'ys': {'xs': None, 'xlens': None},
                     'ys_sub1': {'xs': None, 'xlens': None},
                     'ys_sub2': {'xs': None, 'xlens': None}}
            return eouts
        else:
            if self.input_type == 'speech':
                # Frame stacking
                if self.n_stacks > 1:
                    xs = [stack_frame(x, self.n_stacks, self.n_skips) for x in xs]

                # Splicing
                if self.n_splices > 1:
                    xs = [splice(x, self.n_splices, self.n_stacks) for x in xs]

                xlens = [len(x) for x in xs]
                # Flip acoustic features in the reverse order
                if flip:
                    xs = [torch.from_numpy(np.flip(x, axis=0).copy()).float().cuda(self.device_id)
                          for x in xs]
                else:
                    xs = [np2tensor(x, self.device_id).float() for x in xs]
                xs = pad_list(xs, 0.0)

            elif self.input_type == 'text':
                xlens = [len(x) for x in xs]
                xs = [np2tensor(np.fromiter(x, dtype=np.int64), self.device_id).long() for x in xs]
                xs = pad_list(xs, self.pad)
                xs = self.embed_in(xs)

            # sequence summary network
            if self.ssn is not None:
                xs += self.ssn(xs, xlens)

            # encoder
            enc_outs = self.enc(xs, xlens, task.split('.')[0])

            if self.main_weight < 1 and self.enc_type in ['conv', 'tds', 'gated_conv', 'transformer']:
                for sub in ['sub1', 'sub2']:
                    enc_outs['ys_' + sub]['xs'] = enc_outs['ys']['xs'].clone()
                    enc_outs['ys_' + sub]['xlens'] = enc_outs['ys']['xlens'][:]

            # Bridge between the encoder and decoder
            if self.main_weight > 0 and self.is_bridge:
                enc_outs['ys']['xs'] = self.bridge(enc_outs['ys']['xs'])
            if self.sub1_weight > 0 and self.is_bridge:
                enc_outs['ys_sub1']['xs'] = self.bridge_sub1(enc_outs['ys_sub1']['xs'])
            if self.sub2_weight > 0 and self.is_bridge:
                enc_outs['ys_sub2']['xs'] = self.bridge_sub2(enc_outs['ys_sub2']['xs'])

            return enc_outs

    def get_ctc_probs(self, xs, task='ys', temperature=1, topk=None):
        self.eval()
        with torch.no_grad():
            enc_outs = self.encode(xs, task)
            dir = 'fwd' if self.fwd_weight >= self.bwd_weight else 'bwd'
            if task == 'ys_sub1':
                dir += '_sub1'
            elif task == 'ys_sub2':
                dir += '_sub2'

            if task == 'ys':
                assert self.ctc_weight > 0
            elif task == 'ys_sub1':
                assert self.ctc_weight_sub1 > 0
            elif task == 'ys_sub2':
                assert self.ctc_weight_sub2 > 0
            ctc_probs, indices_topk = getattr(self, 'dec_' + dir).ctc_probs_topk(
                enc_outs[task]['xs'], temperature, topk)
            return ctc_probs, indices_topk, enc_outs[task]['xlens']

    def decode(self, xs, params, idx2token, nbest=1, exclude_eos=False,
               refs_id=None, refs_text=None, utt_ids=None, speakers=None,
               task='ys', ensemble_models=[]):
        """Decoding in the inference stage.

        Args:
            xs (list): A list of length `[B]`, which contains arrays of size `[T, input_dim]`
            params (dict): hyper-parameters for decoding
                beam_width (int): the size of beam
                min_len_ratio (float):
                max_len_ratio (float):
                len_penalty (float): length penalty
                cov_penalty (float): coverage penalty
                cov_threshold (float): threshold for coverage penalty
                lm_weight (float): the weight of RNNLM score
                resolving_unk (bool): not used (to make compatible)
                fwd_bwd_attention (bool):
            idx2token (): converter from index to token
            nbest (int):
            exclude_eos (bool): exclude <eos> from best_hyps_id
            refs_id (list): gold token IDs to compute log likelihood
            refs_text (list): gold transcriptions
            utt_ids (list):
            speakers (list):
            task (str): ys* or ys_sub1* or ys_sub2*
            ensemble_models (list): list of Seq2seq classes
        Returns:
            best_hyps_id (list): A list of length `[B]`, which contains arrays of size `[L]`
            aws (list): A list of length `[B]`, which contains arrays of size `[L, T, n_heads]`

        """
        self.eval()
        with torch.no_grad():
            if task.split('.')[0] == 'ys':
                dir = 'bwd' if self.bwd_weight > 0 and params['recog_bwd_attention'] else 'fwd'
            elif task.split('.')[0] == 'ys_sub1':
                dir = 'fwd_sub1'
            elif task.split('.')[0] == 'ys_sub2':
                dir = 'fwd_sub2'
            else:
                raise ValueError(task)

            # encode
            if self.input_type == 'speech' and self.mtl_per_batch and 'bwd' in dir:
                enc_outs = self.encode(xs, task, flip=True)
            else:
                enc_outs = self.encode(xs, task, flip=False)

            #########################
            # CTC
            #########################
            if (self.fwd_weight == 0 and self.bwd_weight == 0) or \
                    (self.ctc_weight > 0 and params['recog_ctc_weight'] == 1):
                lm = None
                if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_fwd') and self.lm_fwd is not None:
                    lm = getattr(self, 'lm_' + dir)

                best_hyps_id = getattr(self, 'dec_' + dir).decode_ctc(
                    enc_outs[task]['xs'], enc_outs[task]['xlens'],
                    params['recog_beam_width'], lm, params['recog_lm_weight'])
                return best_hyps_id, None, (None, None)

            #########################
            # Attention
            #########################
            else:
                cache_info = (None, None)

                if params['recog_beam_width'] == 1 and not params['recog_fwd_bwd_attention']:
                    best_hyps_id, aws = getattr(self, 'dec_' + dir).greedy(
                        enc_outs[task]['xs'], enc_outs[task]['xlens'],
                        params['recog_max_len_ratio'], exclude_eos,
                        idx2token, refs_id, speakers, params['recog_oracle'])
                else:
                    assert params['recog_batch_size'] == 1

                    ctc_log_probs = None
                    if params['recog_ctc_weight'] > 0:
                        ctc_log_probs = self.dec_fwd.ctc_log_probs(enc_outs[task]['xs'])

                    # forward-backward decoding
                    if params['recog_fwd_bwd_attention']:
                        # forward decoder
                        lm_fwd, lm_bwd = None, None
                        if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_fwd') \
                                and self.lm_fwd is not None:
                            lm_fwd = self.lm_fwd
                        if params['recog_reverse_lm_rescoring'] and hasattr(self, 'lm_bwd') \
                                and self.lm_bwd is not None:
                            lm_bwd = self.lm_bwd

                        # ensemble (forward)
                        ensmbl_eouts_fwd = []
                        ensmbl_elens_fwd = []
                        ensmbl_decs_fwd = []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                enc_outs_e_fwd = model.encode(xs, task, flip=False)
                                ensmbl_eouts_fwd += [enc_outs_e_fwd[task]['xs']]
                                ensmbl_elens_fwd += [enc_outs_e_fwd[task]['xlens']]
                                ensmbl_decs_fwd += [model.dec_fwd]
                                # NOTE: only support for the main task now

                        nbest_hyps_id_fwd, aws_fwd, scores_fwd, cache_info = self.dec_fwd.beam_search(
                            enc_outs[task]['xs'], enc_outs[task]['xlens'],
                            params, idx2token, lm_fwd, lm_bwd, ctc_log_probs,
                            params['recog_beam_width'], False, refs_id, utt_ids, speakers,
                            ensmbl_eouts_fwd, ensmbl_elens_fwd, ensmbl_decs_fwd)

                        # backward decoder
                        lm_bwd, lm_fwd = None, None
                        if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_bwd') \
                                and self.lm_bwd is not None:
                            lm_bwd = self.lm_bwd
                        if params['recog_reverse_lm_rescoring'] and hasattr(self, 'lm_fwd') \
                                and self.lm_fwd is not None:
                            lm_fwd = self.lm_fwd

                        # ensemble (backward)
                        ensmbl_eouts_bwd = []
                        ensmbl_elens_bwd = []
                        ensmbl_decs_bwd = []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                if self.input_type == 'speech' and self.mtl_per_batch:
                                    enc_outs_e_bwd = model.encode(xs, task, flip=True)
                                else:
                                    enc_outs_e_bwd = model.encode(xs, task, flip=False)
                                ensmbl_eouts_bwd += [enc_outs_e_bwd[task]['xs']]
                                ensmbl_elens_bwd += [enc_outs_e_bwd[task]['xlens']]
                                ensmbl_decs_bwd += [model.dec_bwd]
                                # NOTE: only support for the main task now
                                # TODO(hirofumi): merge with the forward for the efficiency

                        flip = False
                        if self.input_type == 'speech' and self.mtl_per_batch:
                            flip = True
                            enc_outs_bwd = self.encode(xs, task, flip=True)
                        else:
                            enc_outs_bwd = enc_outs

                        nbest_hyps_id_bwd, aws_bwd, scores_bwd, _ = self.dec_bwd.beam_search(
                            enc_outs_bwd[task]['xs'], enc_outs[task]['xlens'],
                            params, idx2token, lm_bwd, lm_fwd, ctc_log_probs,
                            params['recog_beam_width'], False, refs_id, utt_ids, speakers,
                            ensmbl_eouts_bwd, ensmbl_elens_bwd, ensmbl_decs_bwd)

                        # forward-backward attention
                        best_hyps_id = fwd_bwd_attention(
                            nbest_hyps_id_fwd, aws_fwd, scores_fwd,
                            nbest_hyps_id_bwd, aws_bwd, scores_bwd,
                            flip, self.eos, params['recog_gnmt_decoding'],
                            params['recog_length_penalty'], idx2token, refs_id)
                        aws = None
                    else:
                        # ensemble
                        ensmbl_eouts = []
                        ensmbl_elens = []
                        ensmbl_decs = []
                        if len(ensemble_models) > 0:
                            for i_e, model in enumerate(ensemble_models):
                                if model.input_type == 'speech' and model.mtl_per_batch \
                                        and 'bwd' in dir:
                                    enc_outs_e = model.encode(xs, task, flip=True)
                                else:
                                    enc_outs_e = model.encode(xs, task, flip=False)
                                ensmbl_eouts += [enc_outs_e[task]['xs']]
                                ensmbl_elens += [enc_outs_e[task]['xlens']]
                                ensmbl_decs += [getattr(model, 'dec_' + dir)]
                                # NOTE: only support for the main task now

                        lm, lm_rev = None, None
                        if params['recog_lm_weight'] > 0 and hasattr(self, 'lm_' + dir) \
                                and getattr(self, 'lm_' + dir) is not None:
                            lm = getattr(self, 'lm_' + dir)
                        if params['recog_reverse_lm_rescoring']:
                            if dir == 'fwd':
                                lm_rev = self.lm_bwd
                            else:
                                raise NotImplementedError

                        nbest_hyps_id, aws, scores, cache_info = getattr(self, 'dec_' + dir).beam_search(
                            enc_outs[task]['xs'], enc_outs[task]['xlens'],
                            params, idx2token, lm, lm_rev, ctc_log_probs,
                            nbest, exclude_eos, refs_id, utt_ids, speakers,
                            ensmbl_eouts, ensmbl_elens, ensmbl_decs)

                        if nbest == 1:
                            best_hyps_id = [hyp[0] for hyp in nbest_hyps_id]
                            aws = [aw[0] for aw in aws]
                        else:
                            return nbest_hyps_id, aws, scores, cache_info
                            # NOTE: nbest >= 2 is used for MWER training only

                return best_hyps_id, aws, cache_info
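# A hypothetical greedy-decoding call for the Seq2seq class above, assuming a
# trained `model`, a list of feature arrays `xs`, and an `idx2token` converter.
# `params` holds the recog_* keys the greedy path reads; the values shown are
# illustrative, not library defaults.
params = {
    'recog_bwd_attention': False,      # use the forward decoder
    'recog_fwd_bwd_attention': False,  # no forward-backward rescoring
    'recog_ctc_weight': 0.0,           # pure attention decoding
    'recog_lm_weight': 0.0,            # no LM fusion during search
    'recog_beam_width': 1,             # beam width 1 selects the greedy path
    'recog_max_len_ratio': 1.0,
    'recog_oracle': False,
    'recog_batch_size': 1,
}
# best_hyps_id, aws, _ = model.decode(xs, params, idx2token, nbest=1, exclude_eos=True)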
def __init__(self, args):
    super(ModelBase, self).__init__()

    # for encoder
    self.input_type = args.input_type
    assert args.input_type in ['speech', 'text']
    self.input_dim = args.input_dim
    self.num_stack = args.num_stack
    self.num_skip = args.num_skip
    self.num_splice = args.num_splice
    self.enc_type = args.enc_type
    self.enc_num_units = args.enc_num_units
    if args.enc_type in ['blstm', 'bgru']:
        self.enc_num_units *= 2

    # for attention layer
    self.att_num_heads_0 = args.att_num_heads
    self.att_num_heads_1 = args.att_num_heads_sub
    self.share_attention = False

    # for decoder
    self.num_classes = args.num_classes
    self.num_classes_sub = args.num_classes_sub
    self.blank = 0
    self.unk = 1
    self.sos = 2
    self.eos = 3
    self.pad = 4
    # NOTE: these are reserved in advance

    # for CTC
    self.ctc_weight_0 = args.ctc_weight
    self.ctc_weight_1 = args.ctc_weight_sub

    # for backward decoder
    assert 0 <= args.bwd_weight <= 1
    assert 0 <= args.bwd_weight_sub <= 1
    self.fwd_weight_0 = 1 - args.bwd_weight
    self.bwd_weight_0 = args.bwd_weight
    self.fwd_weight_1 = 1 - args.bwd_weight
    self.bwd_weight_1 = args.bwd_weight

    # for the sub task
    self.main_task_weight = args.main_task_weight

    # Encoder
    if args.enc_type in ['blstm', 'lstm', 'bgru', 'gru']:
        self.enc = RNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            rnn_type=args.enc_type,
            num_units=args.enc_num_units, num_projs=args.enc_num_projs,
            num_layers=args.enc_num_layers, num_layers_sub=args.enc_num_layers_sub,
            dropout_in=args.dropout_in, dropout_hidden=args.dropout_enc,
            subsample=args.subsample, subsample_type=args.subsample_type,
            batch_first=True,
            num_stack=args.num_stack, num_splice=args.num_splice,
            conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels,
            conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides,
            conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm,
            residual=args.enc_residual, nin=0,
            num_projs_final=args.dec_num_units if args.bridge_layer else 0)
    elif args.enc_type == 'cnn':
        assert args.num_stack == 1 and args.num_splice == 1
        self.enc = CNNEncoder(
            input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
            in_channel=args.conv_in_channel, channels=args.conv_channels,
            kernel_sizes=args.conv_kernel_sizes, strides=args.conv_strides,
            poolings=args.conv_poolings,
            dropout_in=args.dropout_in, dropout_hidden=args.dropout_enc,
            num_projs_final=args.dec_num_units,
            batch_norm=args.conv_batch_norm)
    else:
        raise NotImplementedError()

    # Bridge layer between the encoder and decoder
    if args.enc_type == 'cnn':
        self.enc_num_units = args.dec_num_units
    elif args.bridge_layer:
        self.bridge_0 = LinearND(self.enc_num_units, args.dec_num_units)
        self.enc_num_units = args.dec_num_units
    else:
        self.bridge_0 = lambda x: x

    directions = []
    if self.fwd_weight_0 > 0:
        directions.append('fwd')
    if self.bwd_weight_0 > 0:
        directions.append('bwd')
    for dir in directions:
        if args.ctc_weight < 1:
            # Attention layer
            if args.att_num_heads > 1:
                attention = MultiheadAttentionMechanism(
                    enc_num_units=self.enc_num_units, dec_num_units=args.dec_num_units,
                    att_type=args.att_type, att_dim=args.att_dim,
                    sharpening_factor=args.att_sharpening_factor,
                    sigmoid_smoothing=args.att_sigmoid_smoothing,
                    conv_out_channels=args.att_conv_num_channels,
                    conv_kernel_size=args.att_conv_width,
                    num_heads=args.att_num_heads)
            else:
                attention = AttentionMechanism(
                    enc_num_units=self.enc_num_units, dec_num_units=args.dec_num_units,
                    att_type=args.att_type, att_dim=args.att_dim,
                    sharpening_factor=args.att_sharpening_factor,
                    sigmoid_smoothing=args.att_sigmoid_smoothing,
                    conv_out_channels=args.att_conv_num_channels,
                    conv_kernel_size=args.att_conv_width)

            # Cold fusion
            # if args.rnnlm_cf is not None and dir == 'fwd':
            #     raise NotImplementedError()
            #     # TODO(hirofumi): cold fusion for backward RNNLM
            # else:
            #     args.rnnlm_cf = None
            #
            # # RNNLM initialization
            # if args.rnnlm_config_init is not None and dir == 'fwd':
            #     raise NotImplementedError()
            #     # TODO(hirofumi): RNNLM initialization for backward RNNLM
            # else:
            #     args.rnnlm_init = None
        else:
            attention = None

        # Decoder
        decoder = Decoder(
            attention=attention,
            sos=self.sos, eos=self.eos, pad=self.pad,
            enc_num_units=self.enc_num_units,
            rnn_type=args.dec_type,
            num_units=args.dec_num_units, num_layers=args.dec_num_layers,
            residual=args.dec_residual, emb_dim=args.emb_dim,
            num_classes=self.num_classes,
            logits_temp=args.logits_temp,
            dropout_dec=args.dropout_dec, dropout_emb=args.dropout_emb,
            ss_prob=args.ss_prob, lsm_prob=args.lsm_prob, lsm_type=args.lsm_type,
            init_with_enc=args.init_with_enc,
            ctc_weight=args.ctc_weight if dir == 'fwd' else 0,
            ctc_fc_list=args.ctc_fc_list,
            backward=(dir == 'bwd'),
            rnnlm_cf=args.rnnlm_cf,
            cold_fusion_type=args.cold_fusion_type,
            internal_lm=args.internal_lm,
            rnnlm_init=args.rnnlm_init,
            # rnnlm_weight=args.rnnlm_weight,
            share_softmax=args.share_softmax)
        setattr(self, 'dec_' + dir + '_0', decoder)
    # NOTE: fwd only for the sub task

    if args.main_task_weight < 1:
        if args.ctc_weight_sub < 1:
            # Attention layer
            if args.att_num_heads_sub > 1:
                attention_sub = MultiheadAttentionMechanism(
                    enc_num_units=self.enc_num_units, dec_num_units=args.dec_num_units,
                    att_type=args.att_type, att_dim=args.att_dim,
                    sharpening_factor=args.att_sharpening_factor,
                    sigmoid_smoothing=args.att_sigmoid_smoothing,
                    conv_out_channels=args.att_conv_num_channels,
                    conv_kernel_size=args.att_conv_width,
                    num_heads=args.att_num_heads_sub)
            else:
                attention_sub = AttentionMechanism(
                    enc_num_units=self.enc_num_units, dec_num_units=args.dec_num_units,
                    att_type=args.att_type, att_dim=args.att_dim,
                    sharpening_factor=args.att_sharpening_factor,
                    sigmoid_smoothing=args.att_sigmoid_smoothing,
                    conv_out_channels=args.att_conv_num_channels,
                    conv_kernel_size=args.att_conv_width)
        else:
            attention_sub = None

        # Decoder
        self.dec_fwd_1 = Decoder(
            attention=attention_sub,
            sos=self.sos, eos=self.eos, pad=self.pad,
            enc_num_units=self.enc_num_units,
            rnn_type=args.dec_type,
            num_units=args.dec_num_units, num_layers=args.dec_num_layers,
            residual=args.dec_residual, emb_dim=args.emb_dim,
            num_classes=self.num_classes_sub,
            logits_temp=args.logits_temp,
            dropout_dec=args.dropout_dec, dropout_emb=args.dropout_emb,
            ss_prob=args.ss_prob, lsm_prob=args.lsm_prob, lsm_type=args.lsm_type,
            init_with_enc=args.init_with_enc,
            ctc_weight=args.ctc_weight_sub,
            ctc_fc_list=args.ctc_fc_list)  # sub??

    if args.input_type == 'text':
        if args.num_classes == args.num_classes_sub:
            # Share the embedding layer between input and output
            self.embed_in = decoder.emb
        else:
            self.embed_in = Embedding(num_classes=args.num_classes_sub,
                                      emb_dim=args.emb_dim,
                                      dropout=args.dropout_emb,
                                      ignore_index=self.pad)

    # Initialize weight matrices
    self.init_weights(args.param_init, dist=args.param_init_dist, ignore_keys=['bias'])

    # Initialize all biases with 0
    self.init_weights(0, dist='constant', keys=['bias'])

    # Recurrent weights are orthogonalized
    if args.rec_weight_orthogonal:
        # encoder
        if args.enc_type != 'cnn':
            self.init_weights(args.param_init, dist='orthogonal',
                              keys=[args.enc_type, 'weight'], ignore_keys=['bias'])
            # TODO(hirofumi): in case of CNN + LSTM
        # decoder
        self.init_weights(args.param_init, dist='orthogonal',
                          keys=[args.dec_type, 'weight'], ignore_keys=['bias'])

    # Initialize bias in forget gate with 1
    self.init_forget_gate_bias_with_one()

    # Initialize bias in gating with -1
    if args.rnnlm_cf is not None:
        self.init_weights(-1, dist='constant', keys=['cf_fc_lm_gate.fc.bias'])
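# init_forget_gate_bias_with_one() is called above but defined elsewhere. A
# plausible sketch of what such a helper does for PyTorch LSTMs (an
# illustration under assumptions, not the library's implementation): PyTorch
# packs LSTM biases as [input, forget, cell, output] gates, so the forget
# slice is the second quarter of each bias vector.
import torch.nn as nn

def init_forget_gate_bias_with_one(lstm):
    for name, p in lstm.named_parameters():
        if 'bias' in name:
            n = p.size(0) // 4
            p.data[n:2 * n].fill_(1.0)  # forget-gate slice

_lstm = nn.LSTM(input_size=10, hidden_size=20)
init_forget_gate_bias_with_one(_lstm)
assert bool(_lstm.bias_ih_l0.data[20:40].eq(1.0).all())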
def __init__(self, args):
    super(ModelBase, self).__init__()

    # for encoder
    self.input_type = args.input_type
    self.input_dim = args.input_dim
    self.nstacks = args.nstacks
    self.nskips = args.nskips
    self.nsplices = args.nsplices
    self.enc_type = args.enc_type
    self.enc_nunits = args.enc_nunits
    if args.enc_type in ['blstm', 'bgru']:
        self.enc_nunits *= 2
    self.bridge_layer = args.bridge_layer

    # for attention layer
    self.attn_nheads = args.attn_nheads

    # for decoder
    self.vocab = args.vocab
    self.vocab_sub1 = args.vocab_sub1
    self.vocab_sub2 = args.vocab_sub2
    self.blank = 0
    self.unk = 1
    self.sos = 2  # NOTE: the same index as <eos>
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    # for CTC
    self.ctc_weight = args.ctc_weight
    self.ctc_weight_sub1 = args.ctc_weight_sub1
    self.ctc_weight_sub2 = args.ctc_weight_sub2

    # for backward decoder
    self.fwd_weight = 1 - args.bwd_weight
    self.fwd_weight_sub1 = 1 - args.bwd_weight_sub1
    self.fwd_weight_sub2 = 1 - args.bwd_weight_sub2
    self.bwd_weight = args.bwd_weight
    self.bwd_weight_sub1 = args.bwd_weight_sub1
    self.bwd_weight_sub2 = args.bwd_weight_sub2

    # for the sub tasks
    self.main_weight = 1 - args.sub1_weight - args.sub2_weight
    self.sub1_weight = args.sub1_weight
    self.sub2_weight = args.sub2_weight
    self.mtl_per_batch = args.mtl_per_batch
    self.task_specific_layer = args.task_specific_layer

    # Setting for the CNN encoder
    if args.conv_poolings:
        conv_channels = [int(c) for c in args.conv_channels.split('_')] \
            if len(args.conv_channels) > 0 else []
        conv_kernel_sizes = [[int(c.split(',')[0].replace('(', '')),
                              int(c.split(',')[1].replace(')', ''))]
                             for c in args.conv_kernel_sizes.split('_')] \
            if len(args.conv_kernel_sizes) > 0 else []
        conv_strides = [[int(c.split(',')[0].replace('(', '')),
                         int(c.split(',')[1].replace(')', ''))]
                        for c in args.conv_strides.split('_')] \
            if len(args.conv_strides) > 0 else []
        conv_poolings = [[int(c.split(',')[0].replace('(', '')),
                          int(c.split(',')[1].replace(')', ''))]
                         for c in args.conv_poolings.split('_')] \
            if len(args.conv_poolings) > 0 else []
    else:
        conv_channels = []
        conv_kernel_sizes = []
        conv_strides = []
        conv_poolings = []

    # Encoder
    self.enc = RNNEncoder(
        input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim,
        rnn_type=args.enc_type,
        nunits=args.enc_nunits, nprojs=args.enc_nprojs,
        nlayers=args.enc_nlayers,
        nlayers_sub1=args.enc_nlayers_sub1, nlayers_sub2=args.enc_nlayers_sub2,
        dropout_in=args.dropout_in, dropout=args.dropout_enc,
        subsample=[int(s) for s in args.subsample.split('_')],
        subsample_type=args.subsample_type,
        nstacks=args.nstacks, nsplices=args.nsplices,
        conv_in_channel=args.conv_in_channel, conv_channels=conv_channels,
        conv_kernel_sizes=conv_kernel_sizes, conv_strides=conv_strides,
        conv_poolings=conv_poolings, conv_batch_norm=args.conv_batch_norm,
        residual=args.enc_residual, nin=0,
        layer_norm=args.layer_norm,
        task_specific_layer=args.task_specific_layer and args.ctc_weight > 0,
        task_specific_layer_sub1=args.task_specific_layer,
        task_specific_layer_sub2=args.task_specific_layer)

    # Bridge layer between the encoder and decoder
    if args.enc_type == 'cnn':
        self.bridge = LinearND(self.enc.conv.output_dim, args.dec_nunits,
                               dropout=args.dropout_enc)
        if self.sub1_weight > 0:
            self.bridge_sub1 = LinearND(self.enc.conv.output_dim, args.dec_nunits,
                                        dropout=args.dropout_enc)
        if self.sub2_weight > 0:
            self.bridge_sub2 = LinearND(self.enc.conv.output_dim, args.dec_nunits,
                                        dropout=args.dropout_enc)
        self.enc_nunits = args.dec_nunits
    elif args.bridge_layer:
        self.bridge = LinearND(self.enc_nunits, args.dec_nunits,
                               dropout=args.dropout_enc)
        if self.sub1_weight > 0:
            self.bridge_sub1 = LinearND(self.enc_nunits, args.dec_nunits,
                                        dropout=args.dropout_enc)
        if self.sub2_weight > 0:
            self.bridge_sub2 = LinearND(self.enc_nunits, args.dec_nunits,
                                        dropout=args.dropout_enc)
        self.enc_nunits = args.dec_nunits

    # MAIN TASK
    directions = []
    if self.fwd_weight > 0 or self.ctc_weight > 0:
        directions.append('fwd')
    if self.bwd_weight > 0:
        directions.append('bwd')
    for dir in directions:
        # Cold fusion
        if args.rnnlm_cold_fusion and dir == 'fwd':
            logger.info('cold fusion')
            raise NotImplementedError()
            # TODO(hirofumi): cold fusion for backward RNNLM
        else:
            args.rnnlm_cold_fusion = False

        # TODO(hirofumi): remove later
        if not hasattr(args, 'focal_loss_weight'):
            args.focal_loss_weight = 0.0
            args.focal_loss_gamma = 2.0
        if not hasattr(args, 'tie_embedding'):
            args.tie_embedding = False

        # Decoder
        dec = Decoder(
            sos=self.sos, eos=self.eos, pad=self.pad,
            enc_nunits=self.enc_nunits,
            attn_type=args.attn_type, attn_dim=args.attn_dim,
            attn_sharpening_factor=args.attn_sharpening,
            attn_sigmoid_smoothing=args.attn_sigmoid,
            attn_conv_out_channels=args.attn_conv_nchannels,
            attn_conv_kernel_size=args.attn_conv_width,
            attn_nheads=args.attn_nheads,
            dropout_att=args.dropout_att,
            rnn_type=args.dec_type,
            nunits=args.dec_nunits, nlayers=args.dec_nlayers,
            residual=args.dec_residual, emb_dim=args.emb_dim,
            tie_embedding=args.tie_embedding,
            vocab=self.vocab,
            logits_temp=args.logits_temp,
            dropout=args.dropout_dec, dropout_emb=args.dropout_emb,
            ss_prob=args.ss_prob, lsm_prob=args.lsm_prob,
            layer_norm=args.layer_norm,
            fl_weight=args.focal_loss_weight, fl_gamma=args.focal_loss_gamma,
            init_with_enc=args.init_with_enc,
            ctc_weight=self.ctc_weight if dir == 'fwd' else 0,
            ctc_fc_list=[int(fc) for fc in args.ctc_fc_list.split('_')]
                if len(args.ctc_fc_list) > 0 else [],
            input_feeding=args.input_feeding,
            backward=(dir == 'bwd'),
            rnnlm_cold_fusion=args.rnnlm_cold_fusion,
            cold_fusion=args.cold_fusion,
            internal_lm=args.internal_lm,
            rnnlm_init=args.rnnlm_init,
            lmobj_weight=args.lmobj_weight,
            share_lm_softmax=args.share_lm_softmax,
            global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight,
            mtl_per_batch=args.mtl_per_batch,
            vocab_char=args.vocab_sub1)
        setattr(self, 'dec_' + dir, dec)

    # sub task (only for fwd)
    for sub in ['sub1', 'sub2']:
        if getattr(self, sub + '_weight') > 0:
            # Decoder
            dec_fwd_sub = Decoder(
                sos=self.sos, eos=self.eos, pad=self.pad,
                enc_nunits=self.enc_nunits,
                attn_type=args.attn_type, attn_dim=args.attn_dim,
                attn_sharpening_factor=args.attn_sharpening,
                attn_sigmoid_smoothing=args.attn_sigmoid,
                attn_conv_out_channels=args.attn_conv_nchannels,
                attn_conv_kernel_size=args.attn_conv_width,
                attn_nheads=1,
                dropout_att=args.dropout_att,
                rnn_type=args.dec_type,
                nunits=args.dec_nunits, nlayers=args.dec_nlayers,
                residual=args.dec_residual, emb_dim=args.emb_dim,
                tie_embedding=args.tie_embedding,
                vocab=getattr(self, 'vocab_' + sub),
                logits_temp=args.logits_temp,
                dropout=args.dropout_dec, dropout_emb=args.dropout_emb,
                ss_prob=args.ss_prob, lsm_prob=args.lsm_prob,
                layer_norm=args.layer_norm,
                fl_weight=args.focal_loss_weight, fl_gamma=args.focal_loss_gamma,
                init_with_enc=args.init_with_enc,
                ctc_weight=getattr(self, 'ctc_weight_' + sub),
                ctc_fc_list=[int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_')]
                    if len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else [],
                input_feeding=args.input_feeding,
                internal_lm=args.internal_lm,
                lmobj_weight=getattr(args, 'lmobj_weight_' + sub),
                share_lm_softmax=args.share_lm_softmax,
                global_weight=getattr(self, sub + '_weight'),
                mtl_per_batch=args.mtl_per_batch)
            setattr(self, 'dec_fwd_' + sub, dec_fwd_sub)

    if args.input_type == 'text':
        if args.vocab == args.vocab_sub1:
            # Share the embedding layer between input and output
            self.embed_in = dec.embed
        else:
            self.embed_in = Embedding(vocab=args.vocab_sub1,
                                      emb_dim=args.emb_dim,
                                      dropout=args.dropout_emb,
                                      ignore_index=self.pad)

    # Initialize weight matrices
    self.init_weights(args.param_init, dist=args.param_init_dist, ignore_keys=['bias'])

    # Initialize CNN layers like chainer
    self.init_weights(args.param_init, dist='lecun', keys=['conv'], ignore_keys=['score'])

    # Initialize all biases with 0
    self.init_weights(0, dist='constant', keys=['bias'])

    # Recurrent weights are orthogonalized
    if args.rec_weight_orthogonal:
        # encoder
        if args.enc_type != 'cnn':
            self.init_weights(args.param_init, dist='orthogonal',
                              keys=[args.enc_type, 'weight'], ignore_keys=['bias'])
            # TODO(hirofumi): in case of CNN + LSTM
        # decoder
        self.init_weights(args.param_init, dist='orthogonal',
                          keys=[args.dec_type, 'weight'], ignore_keys=['bias'])

    # Initialize bias in forget gate with 1
    self.init_forget_gate_bias_with_one()

    # Initialize bias in gating with -1
    if args.rnnlm_cold_fusion:
        self.init_weights(-1, dist='constant', keys=['cf_linear_lm_gate.fc.bias'])
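# init_weights() is used throughout with keys/ignore_keys filters but defined
# elsewhere. A minimal sketch of that name-filtered initialization pattern,
# assuming a similar signature; this is not the library's actual helper.
import torch.nn as nn

def init_weights(module, param_init, dist='uniform', keys=None, ignore_keys=None):
    for name, p in module.named_parameters():
        if keys is not None and not any(k in name for k in keys):
            continue  # only touch parameters whose name matches a key
        if ignore_keys is not None and any(k in name for k in ignore_keys):
            continue  # skip explicitly excluded parameters
        if dist == 'constant':
            nn.init.constant_(p, param_init)
        elif dist == 'uniform':
            nn.init.uniform_(p, a=-param_init, b=param_init)
        elif dist == 'orthogonal' and p.dim() >= 2:
            nn.init.orthogonal_(p)

# e.g. zero every bias, as in self.init_weights(0, dist='constant', keys=['bias'])
init_weights(nn.LSTM(10, 20), 0, dist='constant', keys=['bias'])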