def __init__(self, args, save_path=None):
    """Set up the front-end, encoder, main-task and sub-task decoders.

    Hyper-parameters are copied from ``args`` onto the instance, the loss
    weights of the main/sub tasks are derived, and one decoder is created
    per active direction ('fwd' / 'bwd') and per active sub task.

    Args:
        args: hyper-parameter container read through attribute access
            (presumably an ``argparse.Namespace`` -- confirm with caller).
        save_path: stored as ``self.save_path``; not used otherwise here.

    Raises:
        NotImplementedError: a sub task has a positive weight while
            ``args.dec_type`` is ``'transformer'``.

    """
    super(ModelBase, self).__init__()

    def _parse_fc_list(spec):
        # '_'-separated integers; ``None`` or '' yields an empty list.
        if spec is not None and len(spec) > 0:
            return [int(fc) for fc in spec.split('_')]
        return []

    self.save_path = save_path

    # for encoder, decoder
    self.input_type = args.input_type
    self.input_dim = args.input_dim
    self.enc_type = args.enc_type
    n_units = args.enc_n_units
    if args.enc_type in ('blstm', 'bgru', 'conv_blstm', 'conv_bgru'):
        # these encoder types get twice the configured unit count
        n_units *= 2
    self.enc_n_units = n_units
    self.dec_type = args.dec_type

    # for OOV resolution
    self.enc_n_layers = args.enc_n_layers
    self.enc_n_layers_sub1 = args.enc_n_layers_sub1
    self.subsample = [int(factor) for factor in args.subsample.split('_')]

    # for decoder
    self.vocab = args.vocab
    self.vocab_sub1 = args.vocab_sub1
    self.vocab_sub2 = args.vocab_sub2
    # NOTE: reserved in advance
    self.blank, self.unk, self.eos, self.pad = 0, 1, 2, 3

    # for the sub tasks
    self.main_weight = 1 - args.sub1_weight - args.sub2_weight
    self.sub1_weight = args.sub1_weight
    self.sub2_weight = args.sub2_weight
    self.mtl_per_batch = args.mtl_per_batch
    self.task_specific_layer = args.task_specific_layer

    # for CTC (each CTC weight is capped by the weight of its own task)
    self.ctc_weight = min(args.ctc_weight, self.main_weight)
    self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
    self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

    # for backward decoder
    self.bwd_weight = min(args.bwd_weight, self.main_weight)
    self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
    self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
    self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

    # Feature extraction
    self.gaussian_noise = args.gaussian_noise
    self.n_stacks = args.n_stacks
    self.n_skips = args.n_skips
    self.n_splices = args.n_splices
    self.is_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
    self.specaug = None
    if self.is_specaug:
        # SpecAugment is only set up without stacking/skipping/splicing
        assert args.n_stacks == 1 and args.n_skips == 1
        assert args.n_splices == 1
        self.specaug = SpecAugment(
            F=args.freq_width,
            T=args.time_width,
            n_freq_masks=args.n_freq_masks,
            n_time_masks=args.n_time_masks,
            p=args.time_width_upper)

    # Frontend
    self.ssn = None
    if args.sequence_summary_network:
        assert args.input_type == 'speech'
        self.ssn = SequenceSummaryNetwork(
            args.input_dim,
            n_units=512,
            n_layers=3,
            bottleneck_dim=100,
            dropout=0,
            param_init=args.param_init)

    # Encoder
    self.enc = select_encoder(args)
    if args.freeze_encoder:
        for param in self.enc.parameters():
            param.requires_grad = False

    # main task
    directions = []
    if self.fwd_weight > 0 or self.ctc_weight > 0:
        directions.append('fwd')
    if self.bwd_weight > 0:
        directions.append('bwd')

    for direction in directions:
        is_fwd = direction == 'fwd'
        is_bwd = direction == 'bwd'

        # Load the LM for LM fusion
        # TODO(hirofumi): for backward RNNLM
        lm_fusion = None
        if args.lm_fusion and is_fwd:
            lm_fusion = RNNLM(args.lm_conf)
            lm_fusion, _ = load_checkpoint(lm_fusion, args.lm_fusion)

        # Load the LM for LM initialization
        # TODO(hirofumi): for backward RNNLM
        lm_init = None
        if args.lm_init and is_fwd:
            lm_init = RNNLM(args.lm_conf)
            lm_init, _ = load_checkpoint(lm_init, args.lm_init)

        # Keyword arguments passed identically to every decoder class
        shared = dict(
            eos=self.eos,
            unk=self.unk,
            pad=self.pad,
            blank=self.blank,
            enc_n_units=self.enc.output_dim,
            n_layers=args.dec_n_layers,
            vocab=self.vocab,
            dropout=args.dropout_dec,
            dropout_emb=args.dropout_emb,
            lsm_prob=args.lsm_prob,
            # only the forward decoder receives the CTC weight
            ctc_weight=self.ctc_weight if is_fwd else 0,
            ctc_lsm_prob=args.ctc_lsm_prob,
            ctc_fc_list=_parse_fc_list(args.ctc_fc_list),
            global_weight=(self.main_weight - self.bwd_weight
                           if is_fwd else self.bwd_weight),
            mtl_per_batch=args.mtl_per_batch)

        # Decoder
        if args.dec_type == 'transformer':
            dec = TransformerDecoder(
                attn_type=args.transformer_attn_type,
                attn_n_heads=args.transformer_attn_n_heads,
                d_model=args.d_model,
                d_ff=args.d_ff,
                tie_embedding=args.tie_embedding,
                pe_type=args.pe_type,
                layer_norm_eps=args.layer_norm_eps,
                dropout_att=args.dropout_att,
                focal_loss_weight=args.focal_loss_weight,
                focal_loss_gamma=args.focal_loss_gamma,
                backward=is_bwd,
                **shared)
        elif 'transducer' in args.dec_type:
            dec = RNNTransducer(
                rnn_type=args.dec_type,
                n_units=args.dec_n_units,
                n_projs=args.dec_n_projs,
                residual=args.dec_residual,
                bottleneck_dim=args.dec_bottleneck_dim,
                emb_dim=args.emb_dim,
                lm_init=lm_init,
                lmobj_weight=args.lmobj_weight,
                share_lm_softmax=args.share_lm_softmax,
                param_init=args.param_init,
                **shared)
        else:
            dec = RNNDecoder(
                attn_type=args.attn_type,
                attn_dim=args.attn_dim,
                attn_sharpening_factor=args.attn_sharpening,
                attn_sigmoid_smoothing=args.attn_sigmoid,
                attn_conv_out_channels=args.attn_conv_n_channels,
                attn_conv_kernel_size=args.attn_conv_width,
                attn_n_heads=args.attn_n_heads,
                rnn_type=args.dec_type,
                n_units=args.dec_n_units,
                n_projs=args.dec_n_projs,
                loop_type=args.dec_loop_type,
                residual=args.dec_residual,
                bottleneck_dim=args.dec_bottleneck_dim,
                emb_dim=args.emb_dim,
                tie_embedding=args.tie_embedding,
                dropout_att=args.dropout_att,
                zoneout=args.zoneout,
                ss_prob=args.ss_prob,
                ss_type=args.ss_type,
                focal_loss_weight=args.focal_loss_weight,
                focal_loss_gamma=args.focal_loss_gamma,
                input_feeding=args.input_feeding,
                backward=is_bwd,
                lm_fusion=lm_fusion,
                lm_fusion_type=args.lm_fusion_type,
                discourse_aware=args.discourse_aware,
                lm_init=lm_init,
                lmobj_weight=args.lmobj_weight,
                share_lm_softmax=args.share_lm_softmax,
                adaptive_softmax=args.adaptive_softmax,
                param_init=args.param_init,
                replace_sos=args.replace_sos,
                **shared)
        # registered as ``self.dec_fwd`` / ``self.dec_bwd``
        setattr(self, 'dec_' + direction, dec)

    # sub task
    for task in ['sub1', 'sub2']:
        task_weight = getattr(self, task + '_weight')
        if task_weight > 0:
            if args.dec_type == 'transformer':
                raise NotImplementedError
            dec_sub = RNNDecoder(
                eos=self.eos,
                unk=self.unk,
                pad=self.pad,
                blank=self.blank,
                enc_n_units=self.enc_n_units,
                attn_type=args.attn_type,
                attn_dim=args.attn_dim,
                attn_sharpening_factor=args.attn_sharpening,
                attn_sigmoid_smoothing=args.attn_sigmoid,
                attn_conv_out_channels=args.attn_conv_n_channels,
                attn_conv_kernel_size=args.attn_conv_width,
                attn_n_heads=1,
                rnn_type=args.dec_type,
                n_units=args.dec_n_units,
                n_projs=args.dec_n_projs,
                n_layers=args.dec_n_layers,
                loop_type=args.dec_loop_type,
                residual=args.dec_residual,
                bottleneck_dim=args.dec_bottleneck_dim,
                emb_dim=args.emb_dim,
                tie_embedding=args.tie_embedding,
                vocab=getattr(self, 'vocab_' + task),
                dropout=args.dropout_dec,
                dropout_emb=args.dropout_emb,
                dropout_att=args.dropout_att,
                ss_prob=args.ss_prob,
                ss_type=args.ss_type,
                lsm_prob=args.lsm_prob,
                focal_loss_weight=args.focal_loss_weight,
                focal_loss_gamma=args.focal_loss_gamma,
                ctc_weight=getattr(self, 'ctc_weight_' + task),
                ctc_lsm_prob=args.ctc_lsm_prob,
                ctc_fc_list=_parse_fc_list(
                    getattr(args, 'ctc_fc_list_' + task)),
                input_feeding=args.input_feeding,
                global_weight=task_weight,
                mtl_per_batch=args.mtl_per_batch,
                param_init=args.param_init)
            # registered as ``self.dec_fwd_sub1`` / ``self.dec_fwd_sub2``
            setattr(self, 'dec_fwd_' + task, dec_sub)

    if args.input_type == 'text':
        if args.vocab == args.vocab_sub1:
            # Share the embedding layer between input and output
            # (``dec`` is the decoder built last in the main-task loop)
            self.embed = dec.embed
        else:
            self.embed = Embedding(
                vocab=args.vocab_sub1,
                emb_dim=args.emb_dim,
                dropout=args.dropout_emb,
                ignore_index=self.pad)

    # Recurrent weights are orthogonalized
    if args.rec_weight_orthogonal:
        self.reset_parameters(args.param_init, dist='orthogonal',
                              keys=['rnn', 'weight'])

    # Initialize bias in forget gate with 1
    # self.init_forget_gate_bias_with_one()

    # Fix all parameters except for the gating parts in deep fusion
    if args.lm_fusion_type == 'deep' and args.lm_fusion:
        trainable_keys = ('output', 'output_bn', 'linear')
        for name, param in self.named_parameters():
            param.requires_grad = any(key in name for key in trainable_keys)
def __init__(self, args, save_path=None, idx2token=None):
    """Set up the front-end, encoder, main-task and sub-task decoders.

    Hyper-parameters are copied from ``args`` onto the instance, the loss
    weights of the main/sub tasks are derived, and one decoder is built via
    ``build_decoder`` per active direction ('fwd' / 'bwd') and per active
    sub task.

    Args:
        args: hyper-parameter container read through attribute access;
            ``vars(args)`` is kept as ``self.recog_params`` (presumably an
            ``argparse.Namespace`` -- confirm with caller).
        save_path: stored as ``self.save_path``; not used otherwise here.
        idx2token: stored as ``self.idx2token``; not used otherwise here.

    """
    super(ModelBase, self).__init__()

    self.save_path = save_path

    # for encoder, decoder
    self.input_type = args.input_type
    self.input_dim = args.input_dim
    self.enc_type = args.enc_type
    self.enc_n_units = args.enc_n_units
    if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']:
        # these encoder types get twice the configured unit count
        self.enc_n_units *= 2
    self.dec_type = args.dec_type

    # for OOV resolution
    self.enc_n_layers = args.enc_n_layers
    self.enc_n_layers_sub1 = args.enc_n_layers_sub1
    self.subsample = [int(s) for s in args.subsample.split('_')]

    # for decoder
    self.vocab = args.vocab
    self.vocab_sub1 = args.vocab_sub1
    self.vocab_sub2 = args.vocab_sub2
    self.blank = 0
    self.unk = 1
    self.eos = 2
    self.pad = 3
    # NOTE: reserved in advance

    # for the sub tasks
    self.main_weight = 1 - args.sub1_weight - args.sub2_weight
    self.sub1_weight = args.sub1_weight
    self.sub2_weight = args.sub2_weight
    self.mtl_per_batch = args.mtl_per_batch
    self.task_specific_layer = args.task_specific_layer

    # for CTC (each CTC weight is capped by the weight of its own task)
    self.ctc_weight = min(args.ctc_weight, self.main_weight)
    self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
    self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

    # for backward decoder
    self.bwd_weight = min(args.bwd_weight, self.main_weight)
    self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
    self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
    self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

    # for MBR
    self.mbr_training = args.mbr_training
    self.recog_params = vars(args)
    self.idx2token = idx2token

    # Feature extraction
    self.gaussian_noise = args.gaussian_noise
    self.n_stacks = args.n_stacks
    self.n_skips = args.n_skips
    self.n_splices = args.n_splices
    self.use_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0
    self.specaug = None
    self.flip_time_prob = args.flip_time_prob
    self.flip_freq_prob = args.flip_freq_prob
    self.weight_noise = args.weight_noise
    if self.use_specaug:
        # SpecAugment is only set up without stacking/skipping/splicing
        assert args.n_stacks == 1 and args.n_skips == 1
        assert args.n_splices == 1
        self.specaug = SpecAugment(F=args.freq_width,
                                   T=args.time_width,
                                   n_freq_masks=args.n_freq_masks,
                                   n_time_masks=args.n_time_masks,
                                   p=args.time_width_upper)

    # Frontend
    self.ssn = None
    if args.sequence_summary_network:
        assert args.input_type == 'speech'
        self.ssn = SequenceSummaryNetwork(args.input_dim,
                                          n_units=512,
                                          n_layers=3,
                                          bottleneck_dim=100,
                                          dropout=0,
                                          param_init=args.param_init)

    # Encoder
    self.enc = build_encoder(args)
    if args.freeze_encoder:
        for p in self.enc.parameters():
            p.requires_grad = False

    # Built once, before the direction loop, so that the sub-task decoders
    # can also use it when no main-task decoder is constructed (e.g. the
    # main-task weight is zero and ``directions`` ends up empty).
    special_symbols = {
        'blank': self.blank,
        'unk': self.unk,
        'eos': self.eos,
        'pad': self.pad,
    }

    # main task
    external_lm = None
    directions = []
    if self.fwd_weight > 0 or (self.bwd_weight == 0 and self.ctc_weight > 0):
        directions.append('fwd')
    if self.bwd_weight > 0:
        directions.append('bwd')

    for direction in directions:
        # Load the LM for LM fusion and decoder initialization
        if args.external_lm and direction == 'fwd':
            external_lm = RNNLM(args.lm_conf)
            load_checkpoint(external_lm, args.external_lm)
            # freeze LM parameters
            for n, p in external_lm.named_parameters():
                p.requires_grad = False

        # Decoder
        dec = build_decoder(
            args, special_symbols,
            self.enc.output_dim,
            args.vocab,
            self.ctc_weight,
            args.ctc_fc_list,
            (self.main_weight - self.bwd_weight
             if direction == 'fwd' else self.bwd_weight),
            external_lm)
        # registered as ``self.dec_fwd`` / ``self.dec_bwd``
        setattr(self, 'dec_' + direction, dec)

    # sub task
    for sub in ['sub1', 'sub2']:
        if getattr(self, sub + '_weight') > 0:
            dec_sub = build_decoder(
                args, special_symbols,
                self.enc.output_dim,
                getattr(self, 'vocab_' + sub),
                getattr(self, 'ctc_weight_' + sub),
                getattr(args, 'ctc_fc_list_' + sub),
                getattr(self, sub + '_weight'),
                external_lm)
            # registered as ``self.dec_fwd_sub1`` / ``self.dec_fwd_sub2``
            setattr(self, 'dec_fwd_' + sub, dec_sub)

    if args.input_type == 'text':
        if args.vocab == args.vocab_sub1:
            # Share the embedding layer between input and output
            # NOTE(review): ``dec`` is the decoder built last in the
            # main-task loop; this line still fails if that loop did not
            # run -- confirm text input always has a main-task decoder.
            self.embed = dec.embed
        else:
            self.embed = nn.Embedding(args.vocab_sub1, args.emb_dim,
                                      padding_idx=self.pad)
            self.dropout_emb = nn.Dropout(p=args.dropout_emb)

    # Recurrent weights are orthogonalized
    if args.rec_weight_orthogonal:
        self.reset_parameters(args.param_init, dist='orthogonal',
                              keys=['rnn', 'weight'])

    # Initialize bias in forget gate with 1
    # self.init_forget_gate_bias_with_one()

    # Fix all parameters except for the gating parts in deep fusion
    if args.lm_fusion == 'deep' and external_lm is not None:
        for n, p in self.named_parameters():
            # 'output' also matches names containing 'output_bn'
            p.requires_grad = 'output' in n or 'linear' in n
def __init__(self, args, save_path=None, idx2token=None):
    """Set up the front-end, encoder, main-task and sub-task decoders.

    Hyper-parameters are copied from ``args`` onto the instance, the loss
    weights of the main/sub tasks are derived, and one decoder is built via
    ``build_decoder`` per active direction ('fwd' / 'bwd') and per active
    sub task. A sub-task decoder is built from a deep copy of ``args``
    overridden by ``args.dec_config_<sub>`` when that attribute exists.

    Args:
        args: hyper-parameter container read through attribute access;
            ``vars(args)`` is kept as ``self.recog_params`` (presumably an
            ``argparse.Namespace`` -- confirm with caller).
        save_path: stored as ``self.save_path``; not used otherwise here.
        idx2token: stored as ``self.idx2token``; not used otherwise here.

    """
    super(ModelBase, self).__init__()

    self.save_path = save_path

    # for encoder, decoder
    self.input_type = args.input_type
    self.input_dim = args.input_dim
    self.enc_type = args.enc_type
    self.dec_type = args.dec_type

    # for OOV resolution
    self.enc_n_layers = args.enc_n_layers
    self.enc_n_layers_sub1 = args.enc_n_layers_sub1
    self.subsample = [int(factor) for factor in args.subsample.split('_')]

    # for decoder
    self.vocab = args.vocab
    self.vocab_sub1 = args.vocab_sub1
    self.vocab_sub2 = args.vocab_sub2
    # NOTE: reserved in advance
    self.blank, self.unk, self.eos, self.pad = 0, 1, 2, 3

    # for the sub tasks
    self.main_weight = args.total_weight - args.sub1_weight - args.sub2_weight
    self.sub1_weight = args.sub1_weight
    self.sub2_weight = args.sub2_weight
    self.mtl_per_batch = args.mtl_per_batch
    self.task_specific_layer = args.task_specific_layer

    # for CTC (each CTC weight is capped by the weight of its own task)
    self.ctc_weight = min(args.ctc_weight, self.main_weight)
    self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight)
    self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight)

    # for backward decoder
    self.bwd_weight = min(args.bwd_weight, self.main_weight)
    self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight
    self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1
    self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2

    # for MBR
    self.mbr_training = args.mbr_training
    self.recog_params = vars(args)
    self.idx2token = idx2token

    # for discourse-aware model
    self.utt_id_prev = None

    # Feature extraction
    self.input_noise_std = args.input_noise_std
    self.n_stacks = args.n_stacks
    self.n_skips = args.n_skips
    self.n_splices = args.n_splices
    self.weight_noise_std = args.weight_noise_std
    self.specaug = None
    if args.n_freq_masks > 0 or args.n_time_masks > 0:
        # SpecAugment is only set up without stacking/skipping/splicing
        assert args.n_stacks == 1 and args.n_skips == 1
        assert args.n_splices == 1
        self.specaug = SpecAugment(
            F=args.freq_width,
            T=args.time_width,
            n_freq_masks=args.n_freq_masks,
            n_time_masks=args.n_time_masks,
            p=args.time_width_upper,
            adaptive_number_ratio=args.adaptive_number_ratio,
            adaptive_size_ratio=args.adaptive_size_ratio,
            max_n_time_masks=args.max_n_time_masks)

    # Frontend
    self.ssn = None
    if args.sequence_summary_network:
        assert args.input_type == 'speech'
        self.ssn = SequenceSummaryNetwork(
            args.input_dim,
            n_units=512,
            n_layers=3,
            bottleneck_dim=100,
            dropout=0,
            param_init=args.param_init)

    # Encoder
    self.enc = build_encoder(args)
    if args.freeze_encoder:
        for name, param in self.enc.named_parameters():
            # parameters whose name contains 'bridge' or 'sub1' stay trainable
            if not ('bridge' in name or 'sub1' in name):
                param.requires_grad = False
                logger.info('freeze %s' % name)

    special_symbols = {
        'blank': self.blank,
        'unk': self.unk,
        'eos': self.eos,
        'pad': self.pad,
    }

    # main task
    external_lm = None
    directions = []
    if self.fwd_weight > 0 or (self.bwd_weight == 0 and self.ctc_weight > 0):
        directions.append('fwd')
    if self.bwd_weight > 0:
        directions.append('bwd')

    for direction in directions:
        is_fwd = direction == 'fwd'

        # Load the LM for LM fusion and decoder initialization
        if args.external_lm and is_fwd:
            external_lm = RNNLM(args.lm_conf)
            load_checkpoint(args.external_lm, external_lm)
            # freeze LM parameters
            for _, lm_param in external_lm.named_parameters():
                lm_param.requires_grad = False

        # Decoder
        if is_fwd:
            global_weight = self.main_weight - self.bwd_weight
        else:
            global_weight = self.bwd_weight
        dec = build_decoder(
            args,
            special_symbols,
            self.enc.output_dim,
            args.vocab,
            self.ctc_weight,
            global_weight,
            external_lm)
        # registered as ``self.dec_fwd`` / ``self.dec_bwd``
        setattr(self, 'dec_' + direction, dec)

    # sub task
    for task in ['sub1', 'sub2']:
        task_weight = getattr(self, task + '_weight')
        if task_weight > 0:
            args_sub = copy.deepcopy(args)
            config_name = 'dec_config_' + task
            if hasattr(args, config_name):
                for key, value in getattr(args, config_name).items():
                    setattr(args_sub, key, value)
            # NOTE: Other parameters are the same as the main decoder
            dec_sub = build_decoder(
                args_sub,
                special_symbols,
                getattr(self.enc, 'output_dim_' + task),
                getattr(self, 'vocab_' + task),
                getattr(self, 'ctc_weight_' + task),
                task_weight,
                external_lm)
            # registered as ``self.dec_fwd_sub1`` / ``self.dec_fwd_sub2``
            setattr(self, 'dec_fwd_' + task, dec_sub)

    if args.input_type == 'text':
        if args.vocab == args.vocab_sub1:
            # Share the embedding layer between input and output
            # (``dec`` is the decoder built last in the main-task loop)
            self.embed = dec.embed
        else:
            self.embed = nn.Embedding(args.vocab_sub1, args.emb_dim,
                                      padding_idx=self.pad)
            self.dropout_emb = nn.Dropout(p=args.dropout_emb)

    # Initialize bias in forget gate with 1
    # self.init_forget_gate_bias_with_one()

    # Fix all parameters except for the gating parts in deep fusion
    if args.lm_fusion == 'deep' and external_lm is not None:
        trainable_keys = ('output', 'output_bn', 'linear')
        for name, param in self.named_parameters():
            param.requires_grad = any(key in name for key in trainable_keys)