def __init__(self, eos, unk, pad, blank, enc_n_units, attn_type, attn_n_heads,
             n_layers, d_model, d_ff, vocab, tie_embedding=False, pe_type='add',
             layer_norm_eps=1e-12, dropout=0.0, dropout_emb=0.0, dropout_att=0.0,
             lsm_prob=0.0, focal_loss_weight=0.0, focal_loss_gamma=2.0,
             ctc_weight=0.0, ctc_lsm_prob=0.0, ctc_fc_list=[], backward=False,
             global_weight=1.0, mtl_per_batch=False, adaptive_softmax=False):
    super(TransformerDecoder, self).__init__()
    logger = logging.getLogger('training')

    self.eos = eos
    self.unk = unk
    self.pad = pad
    self.blank = blank
    self.vocab = vocab
    self.enc_n_units = enc_n_units
    self.d_model = d_model
    self.n_layers = n_layers
    self.attn_n_heads = attn_n_heads
    self.pe_type = pe_type
    self.lsm_prob = lsm_prob
    self.focal_loss_weight = focal_loss_weight
    self.focal_loss_gamma = focal_loss_gamma
    self.ctc_weight = ctc_weight
    self.bwd = backward
    self.global_weight = global_weight
    self.mtl_per_batch = mtl_per_batch

    if ctc_weight > 0:
        self.ctc = CTC(eos=eos, blank=blank,
                       enc_n_units=enc_n_units, vocab=vocab,
                       dropout=dropout, lsm_prob=ctc_lsm_prob,
                       fc_list=ctc_fc_list, param_init=0.1)

    if ctc_weight < global_weight:
        self.embed = Embedding(vocab, d_model,
                               dropout=0,  # NOTE: do not apply dropout here
                               ignore_index=pad)
        self.pos_enc = PositionalEncoding(d_model, dropout_emb, pe_type)
        self.layers = nn.ModuleList([
            TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads,
                                    dropout, dropout_att, layer_norm_eps)
            for _ in range(n_layers)])
        self.norm_out = nn.LayerNorm(d_model, eps=layer_norm_eps)

        if adaptive_softmax:
            self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss(
                d_model, vocab,
                cutoffs=[round(self.vocab / 15), 3 * round(self.vocab / 15)],
                # cutoffs=[self.vocab // 25, 3 * self.vocab // 5],
                div_value=4.0)
            self.output = None
        else:
            self.adaptive_softmax = None
            self.output = Linear(d_model, vocab)

            # Optionally tie weights as in:
            # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016)
            # https://arxiv.org/abs/1608.05859
            # and
            # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016)
            # https://arxiv.org/abs/1611.01462
            if tie_embedding:
                self.output.fc.weight = self.embed.embed.weight

    # Initialize parameters
    self.reset_parameters()
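# --- Illustrative sketch (editor's addition, not part of the model above) ---
# Minimal weight tying between an output projection and an input embedding in
# plain PyTorch, the same idea as `self.output.fc.weight = self.embed.embed.weight`
# above. Dimensions and the padding index are example values.
import torch
import torch.nn as nn

d_model, vocab = 256, 1000
embed = nn.Embedding(vocab, d_model, padding_idx=3)
output = nn.Linear(d_model, vocab)
output.weight = embed.weight                 # one shared (vocab, d_model) matrix

tokens = torch.randint(0, vocab, (2, 5))     # (batch, seq_len)
logits = output(embed(tokens))               # (batch, seq_len, vocab)
assert output.weight.data_ptr() == embed.weight.data_ptr()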
def __init__(self, args, save_path=None): super(LMBase, self).__init__() logger = logging.getLogger('training') logger.info(self.__class__.__name__) self.save_path = save_path self.emb_dim = args.emb_dim self.n_units = args.n_units self.n_layers = args.n_layers self.lsm_prob = args.lsm_prob self.vocab = args.vocab self.eos = 2 self.pad = 3 # NOTE: reserved in advance # for cache self.cache_theta = 0.2 # smoothing parameter self.cache_lambda = 0.2 # cache weight self.cache_ids = [] self.cache_keys = [] self.cache_attn = [] self.embed = Embedding(vocab=self.vocab, emb_dim=args.emb_dim, dropout=args.dropout_in, ignore_index=self.pad) model_size = args.lm_type.replace('gated_conv_', '') blocks = OrderedDict() if model_size == 'custom': blocks['conv1'] = GLUBlock(args.kernel_size, args.emb_dim, args.n_units, bottlececk_dim=args.n_projs, dropout=args.dropout_hidden) for l in range(args.n_layers - 1): blocks['conv%d' % (l + 2)] = GLUBlock( args.kernel_size, args.n_units, args.n_units, bottlececk_dim=args.n_projs, dropout=args.dropout_hidden) last_dim = args.n_units elif model_size == '8': blocks['conv1'] = GLUBlock(4, args.emb_dim, 900, dropout=args.dropout_hidden) for i in range(1, 8, 1): blocks['conv2-%d' % i] = GLUBlock(4, 900, 900, dropout=args.dropout_hidden) last_dim = 900 elif model_size == '8B': blocks['conv1'] = GLUBlock(1, args.emb_dim, 512, dropout=args.dropout_hidden) for i in range(1, 4, 1): blocks['conv2-%d' % i] = GLUBlock(5, 512, 512, bottlececk_dim=128, dropout=args.dropout_hidden) for i in range(1, 4, 1): blocks['conv3-%d' % i] = GLUBlock(5, 512, 512, bottlececk_dim=256, dropout=args.dropout_hidden) blocks['conv4'] = GLUBlock(1, 512, 2048, bottlececk_dim=1024, dropout=args.dropout_hidden) last_dim = 2048 elif model_size == '9': blocks['conv1'] = GLUBlock(4, args.emb_dim, 807, dropout=args.dropout_hidden) for i in range(1, 4, 1): blocks['conv2-%d-1' % i] = GLUBlock( 4, 807, 807, dropout=args.dropout_hidden) blocks['conv2-%d-2' % i] = GLUBlock( 4, 807, 807, dropout=args.dropout_hidden) last_dim = 807 elif model_size == '13': blocks['conv1'] = GLUBlock(4, args.emb_dim, 1268, dropout=args.dropout_hidden) for i in range(1, 13, 1): blocks['conv2-%d' % i] = GLUBlock(4, 1268, 1268, dropout=args.dropout_hidden) last_dim = 1268 elif model_size == '14': for i in range(1, 4, 1): blocks['conv1-%d' % i] = GLUBlock( 6, args.emb_dim if i == 1 else 850, 850, dropout=args.dropout_hidden) blocks['conv2'] = GLUBlock(1, 850, 850, dropout=args.dropout_hidden) for i in range(1, 5, 1): blocks['conv3-%d' % i] = GLUBlock(5, 850, 850, dropout=args.dropout_hidden) blocks['conv4'] = GLUBlock(1, 850, 850, dropout=args.dropout_hidden) for i in range(1, 4, 1): blocks['conv5-%d' % i] = GLUBlock(4, 850, 850, dropout=args.dropout_hidden) blocks['conv6'] = GLUBlock(4, 850, 1024, dropout=args.dropout_hidden) blocks['conv7'] = GLUBlock(4, 1024, 2048, dropout=args.dropout_hidden) last_dim = 2048 elif model_size == '14B': blocks['conv1'] = GLUBlock(5, args.emb_dim, 512, dropout=args.dropout_hidden) for i in range(1, 4, 1): blocks['conv2-%d' % i] = GLUBlock(5, 512, 512, bottlececk_dim=128, dropout=args.dropout_hidden) for i in range(1, 4, 1): blocks['conv3-%d' % i] = GLUBlock(5, 512 if i == 1 else 1024, 1024, bottlececk_dim=512, dropout=args.dropout_hidden) for i in range(1, 7, 1): blocks['conv4-%d' % i] = GLUBlock(5, 1024 if i == 1 else 2048, 2048, bottlececk_dim=1024, dropout=args.dropout_hidden) blocks['conv5'] = GLUBlock(5, 2048, 4096, bottlececk_dim=1024, dropout=args.dropout_hidden) last_dim = 4096 else: raise 
NotImplementedError(model_size) self.blocks = nn.Sequential(blocks) if args.adaptive_softmax: self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( last_dim, self.vocab, # cutoffs=[self.vocab // 10, 3 * self.vocab // 10], cutoffs=[self.vocab // 25, self.vocab // 5], div_value=4.0) self.output = None else: self.adaptive_softmax = None self.output = LinearND(last_dim, self.vocab, dropout=args.dropout_out) # NOTE: include bias even when tying weights # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if args.tie_embedding: if args.n_units != args.emb_dim: raise ValueError( 'When using the tied flag, n_units must be equal to emb_dim.' ) self.output.fc.weight = self.embed.embed.weight # Initialize parameters self.reset_parameters(args.param_init)
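# --- Illustrative sketch (editor's addition) ---
# How nn.AdaptiveLogSoftmaxWithLoss is typically driven at training time,
# assuming flattened hidden states of size `last_dim` and integer targets.
# The cutoffs mirror the `[vocab // 25, vocab // 5]` choice above but are
# otherwise example values.
import torch
import torch.nn as nn

last_dim, vocab = 512, 10000
asoft = nn.AdaptiveLogSoftmaxWithLoss(last_dim, vocab,
                                      cutoffs=[vocab // 25, vocab // 5],
                                      div_value=4.0)

hidden = torch.randn(8 * 20, last_dim)        # (batch * time, last_dim)
targets = torch.randint(0, vocab, (8 * 20,))  # (batch * time,)
loss = asoft(hidden, targets).loss            # mean negative log-likelihood
log_probs = asoft.log_prob(hidden)            # (batch * time, vocab) for evaluation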
def __init__(self, args, save_path=None): super(ModelBase, self).__init__() self.save_path = save_path # for encoder, decoder self.input_type = args.input_type self.input_dim = args.input_dim self.enc_type = args.enc_type self.enc_n_units = args.enc_n_units if args.enc_type in ['blstm', 'bgru', 'conv_blstm', 'conv_bgru']: self.enc_n_units *= 2 self.dec_type = args.dec_type # for OOV resolution self.enc_n_layers = args.enc_n_layers self.enc_n_layers_sub1 = args.enc_n_layers_sub1 self.subsample = [int(s) for s in args.subsample.split('_')] # for decoder self.vocab = args.vocab self.vocab_sub1 = args.vocab_sub1 self.vocab_sub2 = args.vocab_sub2 self.blank = 0 self.unk = 1 self.eos = 2 self.pad = 3 # NOTE: reserved in advance # for the sub tasks self.main_weight = 1 - args.sub1_weight - args.sub2_weight self.sub1_weight = args.sub1_weight self.sub2_weight = args.sub2_weight self.mtl_per_batch = args.mtl_per_batch self.task_specific_layer = args.task_specific_layer # for CTC self.ctc_weight = min(args.ctc_weight, self.main_weight) self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight) self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight) # for backward decoder self.bwd_weight = min(args.bwd_weight, self.main_weight) self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1 self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2 # Feature extraction self.gaussian_noise = args.gaussian_noise self.n_stacks = args.n_stacks self.n_skips = args.n_skips self.n_splices = args.n_splices self.is_specaug = args.n_freq_masks > 0 or args.n_time_masks > 0 self.specaug = None if self.is_specaug: assert args.n_stacks == 1 and args.n_skips == 1 assert args.n_splices == 1 self.specaug = SpecAugment(F=args.freq_width, T=args.time_width, n_freq_masks=args.n_freq_masks, n_time_masks=args.n_time_masks, p=args.time_width_upper) # Frontend self.ssn = None if args.sequence_summary_network: assert args.input_type == 'speech' self.ssn = SequenceSummaryNetwork(args.input_dim, n_units=512, n_layers=3, bottleneck_dim=100, dropout=0, param_init=args.param_init) # Encoder self.enc = select_encoder(args) if args.freeze_encoder: for p in self.enc.parameters(): p.requires_grad = False # main task directions = [] if self.fwd_weight > 0 or self.ctc_weight > 0: directions.append('fwd') if self.bwd_weight > 0: directions.append('bwd') for dir in directions: # Load the LM for LM fusion if args.lm_fusion and dir == 'fwd': lm_fusion = RNNLM(args.lm_conf) lm_fusion, _ = load_checkpoint(lm_fusion, args.lm_fusion) else: lm_fusion = None # TODO(hirofumi): for backward RNNLM # Load the LM for LM initialization if args.lm_init and dir == 'fwd': lm_init = RNNLM(args.lm_conf) lm_init, _ = load_checkpoint(lm_init, args.lm_init) else: lm_init = None # TODO(hirofumi): for backward RNNLM # Decoder if args.dec_type == 'transformer': dec = TransformerDecoder( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc.output_dim, attn_type=args.transformer_attn_type, attn_n_heads=args.transformer_attn_n_heads, n_layers=args.dec_n_layers, d_model=args.d_model, d_ff=args.d_ff, vocab=self.vocab, tie_embedding=args.tie_embedding, pe_type=args.pe_type, layer_norm_eps=args.layer_norm_eps, dropout=args.dropout_dec, dropout_emb=args.dropout_emb, dropout_att=args.dropout_att, lsm_prob=args.lsm_prob, focal_loss_weight=args.focal_loss_weight, focal_loss_gamma=args.focal_loss_gamma, ctc_weight=self.ctc_weight if dir == 'fwd' else 0, 
ctc_lsm_prob=args.ctc_lsm_prob, ctc_fc_list=[ int(fc) for fc in args.ctc_fc_list.split('_') ] if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [], backward=(dir == 'bwd'), global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight, mtl_per_batch=args.mtl_per_batch) elif 'transducer' in args.dec_type: dec = RNNTransducer( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc.output_dim, rnn_type=args.dec_type, n_units=args.dec_n_units, n_projs=args.dec_n_projs, n_layers=args.dec_n_layers, residual=args.dec_residual, bottleneck_dim=args.dec_bottleneck_dim, emb_dim=args.emb_dim, vocab=self.vocab, dropout=args.dropout_dec, dropout_emb=args.dropout_emb, lsm_prob=args.lsm_prob, ctc_weight=self.ctc_weight if dir == 'fwd' else 0, ctc_lsm_prob=args.ctc_lsm_prob, ctc_fc_list=[ int(fc) for fc in args.ctc_fc_list.split('_') ] if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [], lm_init=lm_init, lmobj_weight=args.lmobj_weight, share_lm_softmax=args.share_lm_softmax, global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight, mtl_per_batch=args.mtl_per_batch, param_init=args.param_init) else: dec = RNNDecoder( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc.output_dim, attn_type=args.attn_type, attn_dim=args.attn_dim, attn_sharpening_factor=args.attn_sharpening, attn_sigmoid_smoothing=args.attn_sigmoid, attn_conv_out_channels=args.attn_conv_n_channels, attn_conv_kernel_size=args.attn_conv_width, attn_n_heads=args.attn_n_heads, rnn_type=args.dec_type, n_units=args.dec_n_units, n_projs=args.dec_n_projs, n_layers=args.dec_n_layers, loop_type=args.dec_loop_type, residual=args.dec_residual, bottleneck_dim=args.dec_bottleneck_dim, emb_dim=args.emb_dim, vocab=self.vocab, tie_embedding=args.tie_embedding, dropout=args.dropout_dec, dropout_emb=args.dropout_emb, dropout_att=args.dropout_att, zoneout=args.zoneout, ss_prob=args.ss_prob, ss_type=args.ss_type, lsm_prob=args.lsm_prob, focal_loss_weight=args.focal_loss_weight, focal_loss_gamma=args.focal_loss_gamma, ctc_weight=self.ctc_weight if dir == 'fwd' else 0, ctc_lsm_prob=args.ctc_lsm_prob, ctc_fc_list=[ int(fc) for fc in args.ctc_fc_list.split('_') ] if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [], input_feeding=args.input_feeding, backward=(dir == 'bwd'), lm_fusion=lm_fusion, lm_fusion_type=args.lm_fusion_type, discourse_aware=args.discourse_aware, lm_init=lm_init, lmobj_weight=args.lmobj_weight, share_lm_softmax=args.share_lm_softmax, global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight, mtl_per_batch=args.mtl_per_batch, adaptive_softmax=args.adaptive_softmax, param_init=args.param_init, replace_sos=args.replace_sos) setattr(self, 'dec_' + dir, dec) # sub task for sub in ['sub1', 'sub2']: if getattr(self, sub + '_weight') > 0: if args.dec_type == 'transformer': raise NotImplementedError else: dec_sub = RNNDecoder( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc_n_units, attn_type=args.attn_type, attn_dim=args.attn_dim, attn_sharpening_factor=args.attn_sharpening, attn_sigmoid_smoothing=args.attn_sigmoid, attn_conv_out_channels=args.attn_conv_n_channels, attn_conv_kernel_size=args.attn_conv_width, attn_n_heads=1, rnn_type=args.dec_type, n_units=args.dec_n_units, n_projs=args.dec_n_projs, n_layers=args.dec_n_layers, loop_type=args.dec_loop_type, residual=args.dec_residual, bottleneck_dim=args.dec_bottleneck_dim, 
emb_dim=args.emb_dim, tie_embedding=args.tie_embedding, vocab=getattr(self, 'vocab_' + sub), dropout=args.dropout_dec, dropout_emb=args.dropout_emb, dropout_att=args.dropout_att, ss_prob=args.ss_prob, ss_type=args.ss_type, lsm_prob=args.lsm_prob, focal_loss_weight=args.focal_loss_weight, focal_loss_gamma=args.focal_loss_gamma, ctc_weight=getattr(self, 'ctc_weight_' + sub), ctc_lsm_prob=args.ctc_lsm_prob, ctc_fc_list=[ int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_') ] if getattr(args, 'ctc_fc_list_' + sub) is not None and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else [], input_feeding=args.input_feeding, global_weight=getattr(self, sub + '_weight'), mtl_per_batch=args.mtl_per_batch, param_init=args.param_init) setattr(self, 'dec_fwd_' + sub, dec_sub) if args.input_type == 'text': if args.vocab == args.vocab_sub1: # Share the embedding layer between input and output self.embed = dec.embed else: self.embed = Embedding(vocab=args.vocab_sub1, emb_dim=args.emb_dim, dropout=args.dropout_emb, ignore_index=self.pad) # Recurrent weights are orthogonalized if args.rec_weight_orthogonal: self.reset_parameters(args.param_init, dist='orthogonal', keys=['rnn', 'weight']) # Initialize bias in forget gate with 1 # self.init_forget_gate_bias_with_one() # Fix all parameters except for the gating parts in deep fusion if args.lm_fusion_type == 'deep' and args.lm_fusion: for n, p in self.named_parameters(): if 'output' in n or 'output_bn' in n or 'linear' in n: p.requires_grad = True else: p.requires_grad = False
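# --- Illustrative sketch (editor's addition) ---
# The deep-fusion freezing policy used in the loop above, factored into a
# hypothetical helper: everything is frozen except the output layers and the
# gating/linear layers ('output', 'output_bn', 'linear' in the parameter name).
def freeze_for_deep_fusion(model):
    for name, param in model.named_parameters():
        param.requires_grad = ('output' in name or 'output_bn' in name
                               or 'linear' in name)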
def __init__(self, args, save_path=None): super(LMBase, self).__init__() logger = logging.getLogger('training') logger.info(self.__class__.__name__) self.save_path = save_path self.emb_dim = args.emb_dim self.rnn_type = args.lm_type assert args.lm_type in ['lstm', 'gru'] self.n_units = args.n_units self.n_projs = args.n_projs self.n_layers = args.n_layers self.residual = args.residual self.use_glu = args.use_glu self.n_units_cv = args.n_units_null_context self.lsm_prob = args.lsm_prob self.vocab = args.vocab self.eos = 2 self.pad = 3 # NOTE: reserved in advance # for cache self.cache_theta = 0.2 # smoothing parameter self.cache_lambda = 0.2 # cache weight self.cache_ids = [] self.cache_keys = [] self.cache_attn = [] self.embed = Embedding(vocab=self.vocab, emb_dim=args.emb_dim, dropout=args.dropout_in, ignore_index=self.pad) rnn = nn.LSTM if args.lm_type == 'lstm' else nn.GRU self.rnn = nn.ModuleList() self.dropout = nn.ModuleList( [nn.Dropout(p=args.dropout_hidden) for _ in range(args.n_layers)]) if args.n_projs > 0: self.proj = nn.ModuleList([ Linear(args.n_units, args.n_projs) for _ in range(args.n_layers) ]) rnn_idim = args.emb_dim + args.n_units_null_context for l in range(args.n_layers): self.rnn += [ rnn(rnn_idim, args.n_units, 1, bias=True, batch_first=True, dropout=0, bidirectional=False) ] rnn_idim = args.n_units if args.n_projs > 0: rnn_idim = args.n_projs if self.use_glu: self.fc_glu = Linear(rnn_idim, rnn_idim * 2, dropout=args.dropout_hidden) if args.adaptive_softmax: self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( rnn_idim, self.vocab, # cutoffs=[self.vocab // 10, 3 * self.vocab // 10], cutoffs=[self.vocab // 25, self.vocab // 5], div_value=4.0) self.output = None else: self.adaptive_softmax = None self.output = Linear(rnn_idim, self.vocab, dropout=args.dropout_out) # NOTE: include bias even when tying weights # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if args.tie_embedding: if args.n_units != args.emb_dim: raise ValueError( 'When using the tied flag, n_units must be equal to emb_dim.' ) self.output.fc.weight = self.embed.embed.weight # Initialize parameters self.reset_parameters(args.param_init) # Recurrent weights are orthogonalized if args.rec_weight_orthogonal: self.reset_parameters(args.param_init, dist='orthogonal', keys=['rnn', 'weight'])
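# --- Illustrative sketch (editor's addition) ---
# The per-layer construction used above: single-layer LSTMs kept in a
# ModuleList so that dropout and optional projection layers can be interleaved
# between layers. Dimensions are example values, not the defaults of the LM.
import torch
import torch.nn as nn

emb_dim, n_units, n_projs, n_layers = 128, 512, 256, 3
rnns, projs, drops = nn.ModuleList(), nn.ModuleList(), nn.ModuleList()
idim = emb_dim
for _ in range(n_layers):
    rnns.append(nn.LSTM(idim, n_units, 1, batch_first=True))
    projs.append(nn.Linear(n_units, n_projs))
    drops.append(nn.Dropout(p=0.2))
    idim = n_projs

xs = torch.randn(4, 10, emb_dim)              # (batch, time, emb_dim)
for rnn, proj, drop in zip(rnns, projs, drops):
    xs, _ = rnn(xs)
    xs = drop(proj(xs))                       # (batch, time, n_projs)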
def __init__(self, args): super(ModelBase, self).__init__() self.emb_dim = args.emb_dim self.rnn_type = args.lm_type assert args.lm_type in ['lstm', 'gru'] self.n_units = args.n_units self.n_layers = args.n_layers self.residual = args.residual self.use_glu = args.use_glu self.vocab = args.vocab self.eos = 2 self.pad = 3 # NOTE: reserved in advance # for cache self.cache_theta = 0.2 # smoothing parameter self.cache_lambda = 0.2 # cache weight self.cache_ids = [] self.cache_keys = [] self.cache_attn = [] self.embed = Embedding(vocab=self.vocab, emb_dim=args.emb_dim, dropout=args.dropout_emb, ignore_index=self.pad) self.fast_impl = False if args.n_projs == 0 and not args.residual: self.fast_impl = True if 'lstm' in args.lm_type: rnn = nn.LSTM elif 'gru' in args.lm_type: rnn = nn.GRU else: raise ValueError('rnn_type must be "(b)lstm" or "(b)gru".') self.rnn = rnn(args.emb_dim, args.n_units, args.n_layers, bias=True, batch_first=True, dropout=args.dropout_hidden, bidirectional=False) # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer rnn_idim = args.n_units self.dropout_top = nn.Dropout(p=args.dropout_hidden) else: self.rnn = torch.nn.ModuleList() self.dropout = torch.nn.ModuleList() if args.n_projs > 0: self.proj = torch.nn.ModuleList() rnn_idim = args.emb_dim for l in range(args.n_layers): self.rnn += [getattr(nn, args.lm_type.upper())( rnn_idim, args.n_units, 1, bias=True, batch_first=True, dropout=0, bidirectional=False)] self.dropout += [nn.Dropout(p=args.dropout_hidden)] rnn_idim = args.n_units if l != self.n_layers - 1 and args.n_projs > 0: self.proj += [LinearND(rnn_idim, args.n_projs)] rnn_idim = args.n_projs if self.use_glu: self.fc_glu = LinearND(rnn_idim, rnn_idim * 2, dropout=args.dropout_hidden) if args.adaptive_softmax: self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( rnn_idim, self.vocab, # cutoffs=[self.vocab // 10, 3 * self.vocab // 10], cutoffs=[self.vocab // 25, self.vocab // 5], div_value=4.0) self.output = None else: self.adaptive_softmax = None self.output = LinearND(rnn_idim, self.vocab, dropout=args.dropout_out) # NOTE: include bias even when tying weights # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if args.tie_embedding: if args.n_units != args.emb_dim: raise ValueError('When using the tied flag, n_units must be equal to emb_dim.') self.output.fc.weight = self.embed.embed.weight # Initialize parameters self.reset_parameters(args.param_init) # Recurrent weights are orthogonalized if args.rec_weight_orthogonal: self.reset_parameters(args.param_init, dist='orthogonal', keys=['rnn', 'weight'])
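# --- Illustrative sketch (editor's addition) ---
# The fast path above in isolation: a single stacked nn.LSTM applies dropout
# between layers but NOT after the last one, hence the explicit dropout_top.
import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=128, hidden_size=512, num_layers=3,
              batch_first=True, dropout=0.2)
dropout_top = nn.Dropout(p=0.2)

xs = torch.randn(4, 10, 128)                  # (batch, time, emb_dim)
out, _ = rnn(xs)
out = dropout_top(out)                        # (batch, time, 512)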
def __init__(self, args, save_path=None): super(LMBase, self).__init__() logger = logging.getLogger('training') logger.info(self.__class__.__name__) self.save_path = save_path self.d_model = args.d_model self.d_ff = args.d_ff self.pe_type = args.pe_type self.n_layers = args.n_layers self.n_heads = args.attn_n_heads self.tie_embedding = args.tie_embedding self.vocab = args.vocab self.eos = 2 self.pad = 3 # NOTE: reserved in advance # self.lsm_prob = lsm_prob # for cache self.cache_theta = 0.2 # smoothing parameter self.cache_lambda = 0.2 # cache weight self.cache_ids = [] self.cache_keys = [] self.cache_attn = [] self.embed = Embedding( vocab=self.vocab, emb_dim=self.d_model, dropout=0, # NOTE: do not apply dropout here ignore_index=self.pad) self.pos_enc = PositionalEncoding(args.d_model, args.dropout_emb, args.pe_type) self.layers = nn.ModuleList([ TransformerDecoderBlock(args.d_model, args.d_ff, args.attn_type, args.attn_n_heads, args.dropout_hidden, args.dropout_att, args.layer_norm_eps, src_attention=False) for _ in range(self.n_layers) ]) self.norm_out = nn.LayerNorm(args.d_model, eps=args.layer_norm_eps) if args.adaptive_softmax: self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( args.d_model, self.vocab, cutoffs=[round(self.vocab / 15), 3 * round(self.vocab / 15)], # cutoffs=[self.vocab // 25, 3 * self.vocab // 5], div_value=4.0) self.output = None else: self.adaptive_softmax = None self.output = LinearND(self.d_model, self.vocab) # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if args.tie_embedding: self.output.fc.weight = self.embed.embed.weight # Initialize parameters self.reset_parameters()
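# --- Illustrative sketch (editor's addition) ---
# A standard additive sinusoidal positional encoding, the usual core of a
# PositionalEncoding module with pe_type='add'. The repo's own implementation
# is not shown here, so treat this as an assumption about its behaviour.
import math
import torch
import torch.nn as nn

class SinusoidalPositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.0, max_len=5000):
        super().__init__()
        self.dropout = nn.Dropout(p=dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2).float()
                        * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer('pe', pe.unsqueeze(0))   # (1, max_len, d_model)

    def forward(self, xs):                             # xs: (batch, time, d_model)
        return self.dropout(xs + self.pe[:, :xs.size(1)])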
def __init__(self, args): super(ModelBase, self).__init__() self.emb_dim = args.emb_dim self.n_units = args.n_units self.n_layers = args.n_layers self.tie_embedding = args.tie_embedding self.backward = args.backward self.vocab = args.vocab self.eos = 2 self.pad = 3 # NOTE: reserved in advance # for cache self.cache_theta = 0.2 # smoothing parameter self.cache_lambda = 0.2 # cache weight self.cache_ids = [] self.cache_keys = [] self.cache_attn = [] self.embed = Embedding(vocab=self.vocab, emb_dim=args.emb_dim, dropout=args.dropout_emb, ignore_index=self.pad) layers = OrderedDict() model_size = args.lm_type.replace('gated_conv_', '') if model_size == 'small': layers['conv1-1'] = GLUBlock(4, args.emb_dim, 600, bottlececk_dim=300, dropout=args.dropout_hidden) layers['conv2-1'] = GLUBlock(4, 600, 600, bottlececk_dim=300, dropout=args.dropout_hidden) layers['conv3-1'] = GLUBlock(4, 600, 600, bottlececk_dim=300, dropout=args.dropout_hidden) layers['conv4-1'] = GLUBlock(4, 600, 600, bottlececk_dim=300, dropout=args.dropout_hidden) layers['conv5-1'] = GLUBlock(4, 600, 600, bottlececk_dim=300, dropout=args.dropout_hidden) last_dim = 600 elif model_size == '8': layers['conv1-1'] = GLUBlock(4, args.emb_dim, 900, dropout=args.dropout_hidden) for i in range(1, 8, 1): layers['conv2-%d' % i] = GLUBlock(4, 900, 900, dropout=args.dropout_hidden) last_dim = 900 elif model_size == '8B': raise NotImplementedError elif model_size == '9': raise NotImplementedError elif model_size == '13': layers['conv1-1'] = GLUBlock(4, args.emb_dim, 1268, dropout=args.dropout_hidden) for i in range(1, 13, 1): layers['conv2-%d' % i] = GLUBlock(4, 1268, 1268, dropout=args.dropout_hidden) last_dim = 1268 elif model_size == '14': for i in range(1, 4, 1): layers['conv1-%d' % i] = GLUBlock( 6, args.emb_dim if i == 1 else 850, 850, dropout=args.dropout_hidden) layers['conv2-1'] = GLUBlock(1, 850, 850, dropout=args.dropout_hidden) for i in range(1, 5, 1): layers['conv3-%d' % i] = GLUBlock(5, 850, 850, dropout=args.dropout_hidden) layers['conv4-1'] = GLUBlock(1, 850, 850, dropout=args.dropout_hidden) for i in range(1, 4, 1): layers['conv5-%d' % i] = GLUBlock(4, 850, 850, dropout=args.dropout_hidden) layers['conv6-1'] = GLUBlock(4, 850, 1024, dropout=args.dropout_hidden) layers['conv7-1'] = GLUBlock(4, 1024, 2048, dropout=args.dropout_hidden) last_dim = 2048 elif model_size == '14B': layers['conv1-1'] = GLUBlock(5, args.emb_dim, 512, dropout=args.dropout_hidden) for i in range(1, 4, 1): layers['conv2-%d' % i] = GLUBlock(5, 512, 512, bottlececk_dim=128, dropout=args.dropout_hidden) for i in range(1, 4, 1): layers['conv3-%d' % i] = GLUBlock(5, 512 if i == 1 else 1024, 1024, bottlececk_dim=512, dropout=args.dropout_hidden) for i in range(1, 7, 1): layers['conv4-%d' % i] = GLUBlock(5, 1024 if i == 1 else 2048, 2048, bottlececk_dim=1024, dropout=args.dropout_hidden) layers['conv5-1'] = GLUBlock(5, 2048, 4096, bottlececk_dim=1024, dropout=args.dropout_hidden) last_dim = 4096 else: raise NotImplementedError(model_size) self.layers = nn.Sequential(layers) if args.adaptive_softmax: self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( last_dim, self.vocab, # cutoffs=[self.vocab // 10, 3 * self.vocab // 10], cutoffs=[self.vocab // 25, self.vocab // 5], div_value=4.0) self.output = None else: self.adaptive_softmax = None self.output = LinearND(last_dim, self.vocab, dropout=args.dropout_out) # NOTE: include bias even when tying weights # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # 
https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if args.tie_embedding: if args.n_units != args.emb_dim: raise ValueError( 'When using the tied flag, n_units must be equal to emb_dim.' ) self.output.fc.weight = self.embed.embed.weight # Initialize parameters self.reset_parameters(args.param_init, dist=args.param_init_dist) # Initialize bias vectors with zero self.reset_parameters(0, dist='constant', keys=['bias'])
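# --- Illustrative sketch (editor's addition) ---
# A hypothetical reset_parameters-style helper that initializes only the
# parameters whose names contain all of the given keys, as in the calls above
# (dist='constant' for ['bias'], dist='orthogonal' for ['rnn', 'weight']).
# The matching rule is an assumption; the repo's own helper is not shown here.
import torch.nn as nn

def reset_parameters(model, param_init, dist='uniform', keys=None):
    for name, p in model.named_parameters():
        if keys is not None and not all(k in name for k in keys):
            continue
        if dist == 'constant':
            nn.init.constant_(p, param_init)
        elif dist == 'orthogonal' and p.dim() >= 2:
            nn.init.orthogonal_(p)
        elif dist == 'uniform':
            nn.init.uniform_(p, -param_init, param_init)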
def __init__(self, eos, unk, pad, blank, enc_n_units, attn_type, attn_n_heads, n_layers, d_model, d_ff, pe_type, tie_embedding, vocab, dropout=0.0, dropout_emb=0.0, dropout_att=0.0, lsm_prob=0.0, layer_norm_eps=1e-6, ctc_weight=0.0, ctc_fc_list=[], backward=False, global_weight=1.0, mtl_per_batch=False, adaptive_softmax=False): super(TransformerDecoder, self).__init__() self.eos = eos self.unk = unk self.pad = pad self.blank = blank self.enc_n_units = enc_n_units self.d_model = d_model self.n_layers = n_layers self.pe_type = pe_type self.lsm_prob = lsm_prob self.ctc_weight = ctc_weight self.ctc_fc_list = ctc_fc_list self.backward = backward self.global_weight = global_weight self.mtl_per_batch = mtl_per_batch if ctc_weight > 0: # Fully-connected layers for CTC if len(ctc_fc_list) > 0: fc_layers = OrderedDict() for i in range(len(ctc_fc_list)): input_dim = d_model if i == 0 else ctc_fc_list[i - 1] fc_layers['fc' + str(i)] = LinearND(input_dim, ctc_fc_list[i], dropout=dropout) fc_layers['fc' + str(len(ctc_fc_list))] = LinearND(ctc_fc_list[-1], vocab, dropout=0) self.output_ctc = nn.Sequential(fc_layers) else: self.output_ctc = LinearND(d_model, vocab) self.decode_ctc_greedy = GreedyDecoder(blank=blank) self.decode_ctc_beam = BeamSearchDecoder(blank=blank) import warpctc_pytorch self.warpctc_loss = warpctc_pytorch.CTCLoss(size_average=True) if ctc_weight < global_weight: self.layers = nn.ModuleList( [TransformerDecoderBlock(d_model, d_ff, attn_type, attn_n_heads, dropout, dropout_att, layer_norm_eps) for _ in range(n_layers)]) self.embed = Embedding(vocab, d_model, dropout=0, # NOTE: do not apply dropout here ignore_index=pad) if pe_type: self.pos_emb_out = PositionalEncoding(d_model, dropout_emb, pe_type) if adaptive_softmax: self.adaptive_softmax = nn.AdaptiveLogSoftmaxWithLoss( d_model, vocab, cutoffs=[round(self.vocab / 15), 3 * round(self.vocab / 15)], # cutoffs=[self.vocab // 25, 3 * self.vocab // 5], div_value=4.0) self.output = None else: self.adaptive_softmax = None self.output = LinearND(d_model, vocab) # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if tie_embedding: self.output.fc.weight = self.embed.embed.weight self.norm_top = nn.LayerNorm(d_model, eps=layer_norm_eps) # Initialize parameters self.reset_parameters()
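# --- Illustrative sketch (editor's addition) ---
# Building the optional CTC head above as a stand-alone function: a stack of
# intermediate fully-connected layers (ctc_fc_list) followed by a projection to
# the vocabulary. Plain nn.Linear stands in for the repo's LinearND wrapper.
from collections import OrderedDict
import torch.nn as nn

def build_ctc_head(d_model, vocab, ctc_fc_list, dropout=0.1):
    if not ctc_fc_list:
        return nn.Linear(d_model, vocab)
    fc_layers = OrderedDict()
    in_dim = d_model
    for i, fc_dim in enumerate(ctc_fc_list):
        fc_layers['fc%d' % i] = nn.Sequential(nn.Linear(in_dim, fc_dim),
                                              nn.Dropout(p=dropout))
        in_dim = fc_dim
    fc_layers['fc%d' % len(ctc_fc_list)] = nn.Linear(in_dim, vocab)
    return nn.Sequential(fc_layers)

# e.g. build_ctc_head(256, 10000, [512, 512]): 256 -> 512 -> 512 -> 10000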
def __init__(self, eos, unk, pad, blank, enc_n_units, rnn_type, n_units, n_projs, n_layers, bottleneck_dim, emb_dim, vocab, tie_embedding=False, attn_conv_kernel_size=0, dropout=0.0, dropout_emb=0.0, lsm_prob=0.0, ctc_weight=0.0, ctc_lsm_prob=0.0, ctc_fc_list=[], backward=False, lm_fusion=None, lm_fusion_type='cold', discourse_aware='', lm_init=None, global_weight=1.0, mtl_per_batch=False, param_init=0.1, replace_sos=False, soft_label_weight=0.0): super(CIFRNNDecoder, self).__init__() logger = logging.getLogger('training') self.eos = eos self.unk = unk self.pad = pad self.blank = blank self.vocab = vocab self.rnn_type = rnn_type assert rnn_type in ['lstm', 'gru'] self.enc_n_units = enc_n_units self.dec_n_units = n_units self.n_projs = n_projs self.n_layers = n_layers self.lsm_prob = lsm_prob self.ctc_weight = ctc_weight self.bwd = backward self.lm_fusion_type = lm_fusion_type self.global_weight = global_weight self.mtl_per_batch = mtl_per_batch self.replace_sos = replace_sos self.soft_label_weight = soft_label_weight self.quantity_loss_weight = 1.0 # for contextualization self.discourse_aware = discourse_aware self.dstate_prev = None # for cache self.prev_spk = '' self.total_step = 0 self.dstates_final = None self.lmstate_final = None if ctc_weight > 0: self.ctc = CTC(eos=eos, blank=blank, enc_n_units=enc_n_units, vocab=vocab, dropout=dropout, lsm_prob=ctc_lsm_prob, fc_list=ctc_fc_list, param_init=param_init) if ctc_weight < global_weight: # Attention layer self.score = CIF(enc_dim=self.enc_n_units, conv_kernel_size=attn_conv_kernel_size, conv_out_channels=self.enc_n_units) # Decoder self.rnn = nn.ModuleList() if self.n_projs > 0: self.proj = nn.ModuleList( [Linear(n_units, n_projs) for _ in range(n_layers)]) self.dropout = nn.ModuleList( [nn.Dropout(p=dropout) for _ in range(n_layers)]) rnn = nn.LSTM if rnn_type == 'lstm' else nn.GRU dec_odim = enc_n_units + emb_dim for l in range(n_layers): self.rnn += [rnn(dec_odim, n_units, 1)] dec_odim = n_units if self.n_projs > 0: dec_odim = n_projs # LM fusion if lm_fusion is not None: self.linear_dec_feat = Linear(dec_odim + enc_n_units, n_units) if lm_fusion_type in ['cold', 'deep']: self.linear_lm_feat = Linear(lm_fusion.n_units, n_units) self.linear_lm_gate = Linear(n_units * 2, n_units) elif lm_fusion_type == 'cold_prob': self.linear_lm_feat = Linear(lm_fusion.vocab, n_units) self.linear_lm_gate = Linear(n_units * 2, n_units) else: raise ValueError(lm_fusion_type) self.output_bn = Linear(n_units * 2, bottleneck_dim) # fix LM parameters for p in lm_fusion.parameters(): p.requires_grad = False elif discourse_aware == 'hierarchical': raise NotImplementedError else: self.output_bn = Linear(dec_odim + enc_n_units, bottleneck_dim) self.embed = Embedding(vocab, emb_dim, dropout=dropout_emb, ignore_index=pad) self.output = Linear(bottleneck_dim, vocab) # NOTE: include bias even when tying weights # Optionally tie weights as in: # "Using the Output Embedding to Improve Language Models" (Press & Wolf 2016) # https://arxiv.org/abs/1608.05859 # and # "Tying Word Vectors and Word Classifiers: A Loss Framework for Language Modeling" (Inan et al. 2016) # https://arxiv.org/abs/1611.01462 if tie_embedding: if emb_dim != bottleneck_dim: raise ValueError( 'When using the tied flag, n_units must be equal to emb_dim.' 
) self.output.fc.weight = self.embed.embed.weight # Initialize parameters self.reset_parameters(param_init) # register the external LM self.lm = lm_fusion # decoder initialization with pre-trained LM if lm_init is not None: assert lm_init.vocab == vocab assert lm_init.n_units == n_units assert lm_init.emb_dim == emb_dim logger.info('===== Initialize the decoder with pre-trained RNNLM') assert lm_init.n_projs == 0 # TODO(hirofumi): fix later assert lm_init.n_units_null_context == enc_n_units # RNN for l in range(lm_init.n_layers): for n, p in lm_init.rnn[l].named_parameters(): assert getattr(self.rnn[l], n).size() == p.size() getattr(self.rnn[l], n).data = p.data logger.info('Overwrite %s' % n) # embedding assert self.embed.embed.weight.size() == lm_init.embed.embed.weight.size() self.embed.embed.weight.data = lm_init.embed.embed.weight.data logger.info('Overwrite %s' % 'embed.embed.weight')
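# --- Illustrative sketch (editor's addition) ---
# The LM-based initialization above, factored into a hypothetical helper:
# parameters of a pre-trained module are copied into a freshly built one
# wherever the names and shapes line up, skipping the output layer.
import logging

def init_from_pretrained(target, pretrained, skip_keys=('output',)):
    logger = logging.getLogger('training')
    pretrained_params = dict(pretrained.named_parameters())
    for name, p in target.named_parameters():
        if any(k in name for k in skip_keys):
            continue
        if name in pretrained_params and p.size() == pretrained_params[name].size():
            p.data.copy_(pretrained_params[name].data)
            logger.info('Overwrite %s' % name)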
def __init__(self, eos, unk, pad, blank, enc_n_units, rnn_type, n_units, n_projs, n_layers, residual, bottleneck_dim, emb_dim, vocab, tie_embedding=False, dropout=0.0, dropout_emb=0.0, lsm_prob=0.0, ctc_weight=0.0, ctc_lsm_prob=0.0, ctc_fc_list=[], lm_init=None, lmobj_weight=0.0, share_lm_softmax=False, global_weight=1.0, mtl_per_batch=False, param_init=0.1, start_pointing=False, end_pointing=True): super(RNNTransducer, self).__init__() logger = logging.getLogger('training') self.eos = eos self.unk = unk self.pad = pad self.blank = blank self.vocab = vocab self.rnn_type = rnn_type assert rnn_type in ['lstm_transducer', 'gru_transducer'] self.enc_n_units = enc_n_units self.dec_n_units = n_units self.n_projs = n_projs self.n_layers = n_layers self.residual = residual self.lsm_prob = lsm_prob self.ctc_weight = ctc_weight self.lmobj_weight = lmobj_weight self.share_lm_softmax = share_lm_softmax self.global_weight = global_weight self.mtl_per_batch = mtl_per_batch # VAD self.start_pointing = start_pointing self.end_pointing = end_pointing # for cache self.prev_spk = '' self.lmstate_final = None self.state_cache = OrderedDict() if ctc_weight > 0: self.ctc = CTC(eos=eos, blank=blank, enc_n_units=enc_n_units, vocab=vocab, dropout=dropout, lsm_prob=ctc_lsm_prob, fc_list=ctc_fc_list, param_init=param_init) if ctc_weight < global_weight: import warprnnt_pytorch self.warprnnt_loss = warprnnt_pytorch.RNNTLoss() # for MTL with LM objective if lmobj_weight > 0: if share_lm_softmax: self.output_lmobj = self.output # share paramters else: self.output_lmobj = Linear(n_units, vocab) # Prediction network self.fast_impl = False rnn = nn.LSTM if rnn_type == 'lstm_transducer' else nn.GRU if n_projs == 0 and not residual: self.fast_impl = True self.rnn = rnn(emb_dim, n_units, n_layers, bias=True, batch_first=True, dropout=dropout, bidirectional=False) # NOTE: pytorch introduces a dropout layer on the outputs of each layer EXCEPT the last layer dec_idim = n_units self.dropout_top = nn.Dropout(p=dropout) else: self.rnn = nn.ModuleList() self.dropout = nn.ModuleList([nn.Dropout(p=dropout) for _ in range(n_layers)]) if n_projs > 0: self.proj = nn.ModuleList([Linear(dec_idim, n_projs) for _ in range(n_layers)]) dec_idim = emb_dim for l in range(n_layers): self.rnn += [rnn(dec_idim, n_units, 1, bias=True, batch_first=True, dropout=0, bidirectional=False)] dec_idim = n_projs if n_projs > 0 else n_units self.embed = Embedding(vocab, emb_dim, dropout=dropout_emb, ignore_index=pad) self.w_enc = Linear(enc_n_units, bottleneck_dim, bias=True) self.w_dec = Linear(dec_idim, bottleneck_dim, bias=False) self.output = Linear(bottleneck_dim, vocab) # Initialize parameters self.reset_parameters(param_init) # prediction network initialization with pre-trained LM if lm_init is not None: assert lm_init.vocab == vocab assert lm_init.n_units == n_units assert lm_init.n_projs == n_projs assert lm_init.n_layers == n_layers assert lm_init.residual == residual param_dict = dict(lm_init.named_parameters()) for n, p in self.named_parameters(): if n in param_dict.keys() and p.size() == param_dict[n].size(): if 'output' in n: continue p.data = param_dict[n].data logger.info('Overwrite %s' % n)
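# --- Illustrative sketch (editor's addition) ---
# The transducer joint network implied by w_enc / w_dec / output above:
# encoder and prediction-network states are projected into a joint space,
# combined additively, and mapped to vocabulary logits (including blank).
# The tanh non-linearity is an assumption; the joint itself is not shown above.
import torch
import torch.nn as nn

class JointNetwork(nn.Module):
    def __init__(self, enc_n_units, dec_n_units, bottleneck_dim, vocab):
        super().__init__()
        self.w_enc = nn.Linear(enc_n_units, bottleneck_dim, bias=True)
        self.w_dec = nn.Linear(dec_n_units, bottleneck_dim, bias=False)
        self.output = nn.Linear(bottleneck_dim, vocab)

    def forward(self, enc_out, dec_out):
        # enc_out: (batch, T, enc_n_units), dec_out: (batch, U, dec_n_units)
        joint = torch.tanh(self.w_enc(enc_out).unsqueeze(2)
                           + self.w_dec(dec_out).unsqueeze(1))
        return self.output(joint)             # (batch, T, U, vocab)

logits = JointNetwork(320, 512, 256, 1000)(torch.randn(2, 50, 320),
                                           torch.randn(2, 10, 512))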
def __init__(self, args): super(ModelBase, self).__init__() # for encoder self.input_type = args.input_type self.input_dim = args.input_dim self.n_stacks = args.n_stacks self.n_skips = args.n_skips self.n_splices = args.n_splices self.enc_type = args.enc_type self.enc_n_units = args.enc_n_units if args.enc_type in ['blstm', 'bgru']: self.enc_n_units *= 2 self.bridge_layer = args.bridge_layer # for OOV resolution self.enc_n_layers = args.enc_n_layers self.enc_n_layers_sub1 = args.enc_n_layers_sub1 self.subsample = [int(s) for s in args.subsample.split('_')] # for attention layer self.attn_n_heads = args.attn_n_heads # for decoder self.vocab = args.vocab self.vocab_sub1 = args.vocab_sub1 self.vocab_sub2 = args.vocab_sub2 self.blank = 0 self.unk = 1 self.eos = 2 self.pad = 3 # NOTE: reserved in advance # for the sub tasks self.main_weight = 1 - args.sub1_weight - args.sub2_weight self.sub1_weight = args.sub1_weight self.sub2_weight = args.sub2_weight self.mtl_per_batch = args.mtl_per_batch self.task_specific_layer = args.task_specific_layer # for CTC self.ctc_weight = min(args.ctc_weight, self.main_weight) self.ctc_weight_sub1 = min(args.ctc_weight_sub1, self.sub1_weight) self.ctc_weight_sub2 = min(args.ctc_weight_sub2, self.sub2_weight) # for backward decoder self.bwd_weight = min(args.bwd_weight, self.main_weight) self.fwd_weight = self.main_weight - self.bwd_weight - self.ctc_weight self.fwd_weight_sub1 = self.sub1_weight - self.ctc_weight_sub1 self.fwd_weight_sub2 = self.sub2_weight - self.ctc_weight_sub2 # Feature extraction self.ssn = None if args.sequence_summary_network: assert args.input_type == 'speech' self.ssn = SequenceSummaryNetwork(args.input_dim, n_units=512, n_layers=3, bottleneck_dim=100, dropout=0) # Encoder if args.enc_type == 'transformer': self.enc = TransformerEncoder( input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim, attn_type=args.transformer_attn_type, attn_n_heads=args.transformer_attn_n_heads, n_layers=args.transformer_enc_n_layers, d_model=args.d_model, d_ff=args.d_ff, # pe_type=args.pe_type, pe_type=False, dropout_in=args.dropout_in, dropout=args.dropout_enc, dropout_att=args.dropout_att, layer_norm_eps=args.layer_norm_eps, n_stacks=args.n_stacks, n_splices=args.n_splices, conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels, conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides, conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm, conv_residual=args.conv_residual, conv_bottleneck_dim=args.conv_bottleneck_dim) else: self.enc = RNNEncoder( input_dim=args.input_dim if args.input_type == 'speech' else args.emb_dim, rnn_type=args.enc_type, n_units=args.enc_n_units, n_projs=args.enc_n_projs, n_layers=args.enc_n_layers, n_layers_sub1=args.enc_n_layers_sub1, n_layers_sub2=args.enc_n_layers_sub2, dropout_in=args.dropout_in, dropout=args.dropout_enc, subsample=list(map(int, args.subsample.split('_'))) + [1] * (args.enc_n_layers - len(args.subsample.split('_'))), subsample_type=args.subsample_type, n_stacks=args.n_stacks, n_splices=args.n_splices, conv_in_channel=args.conv_in_channel, conv_channels=args.conv_channels, conv_kernel_sizes=args.conv_kernel_sizes, conv_strides=args.conv_strides, conv_poolings=args.conv_poolings, conv_batch_norm=args.conv_batch_norm, conv_residual=args.conv_residual, conv_bottleneck_dim=args.conv_bottleneck_dim, residual=args.enc_residual, nin=args.enc_nin, task_specific_layer=args.task_specific_layer) # NOTE: pure CNN/TDS encoders are also included if 
args.freeze_encoder: for p in self.enc.parameters(): p.requires_grad = False # Bridge layer between the encoder and decoder self.is_bridge = False if (args.enc_type in ['conv', 'tds', 'gated_conv', 'transformer'] and args.ctc_weight < 1 ) or args.dec_type == 'transformer' or args.bridge_layer: self.bridge = LinearND(self.enc.output_dim, args.d_model if args.dec_type == 'transformer' else args.dec_n_units, dropout=args.dropout_enc) self.is_bridge = True if self.sub1_weight > 0: self.bridge_sub1 = LinearND(self.enc.output_dim, args.dec_n_units, dropout=args.dropout_enc) if self.sub2_weight > 0: self.bridge_sub2 = LinearND(self.enc.output_dim, args.dec_n_units, dropout=args.dropout_enc) self.enc_n_units = args.dec_n_units # main task directions = [] if self.fwd_weight > 0 or self.ctc_weight > 0: directions.append('fwd') if self.bwd_weight > 0: directions.append('bwd') for dir in directions: # Cold fusion if args.lm_fusion and dir == 'fwd': lm = RNNLM(args.lm_conf) lm, _ = load_checkpoint(lm, args.lm_fusion) else: args.lm_conf = False lm = None # TODO(hirofumi): cold fusion for backward RNNLM # Decoder if args.dec_type == 'transformer': dec = TransformerDecoder( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc.output_dim, attn_type=args.transformer_attn_type, attn_n_heads=args.transformer_attn_n_heads, n_layers=args.transformer_dec_n_layers, d_model=args.d_model, d_ff=args.d_ff, pe_type=args.pe_type, tie_embedding=args.tie_embedding, vocab=self.vocab, dropout=args.dropout_dec, dropout_emb=args.dropout_emb, dropout_att=args.dropout_att, lsm_prob=args.lsm_prob, layer_norm_eps=args.layer_norm_eps, ctc_weight=self.ctc_weight if dir == 'fwd' else 0, ctc_fc_list=[ int(fc) for fc in args.ctc_fc_list.split('_') ] if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [], backward=(dir == 'bwd'), global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight, mtl_per_batch=args.mtl_per_batch) else: dec = RNNDecoder( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc.output_dim, attn_type=args.attn_type, attn_dim=args.attn_dim, attn_sharpening_factor=args.attn_sharpening, attn_sigmoid_smoothing=args.attn_sigmoid, attn_conv_out_channels=args.attn_conv_n_channels, attn_conv_kernel_size=args.attn_conv_width, attn_n_heads=args.attn_n_heads, rnn_type=args.dec_type, n_units=args.dec_n_units, n_projs=args.dec_n_projs, n_layers=args.dec_n_layers, loop_type=args.dec_loop_type, residual=args.dec_residual, bottleneck_dim=args.dec_bottleneck_dim, emb_dim=args.emb_dim, tie_embedding=args.tie_embedding, vocab=self.vocab, dropout=args.dropout_dec, dropout_emb=args.dropout_emb, dropout_att=args.dropout_att, ss_prob=args.ss_prob, ss_type=args.ss_type, lsm_prob=args.lsm_prob, fl_weight=args.focal_loss_weight, fl_gamma=args.focal_loss_gamma, ctc_weight=self.ctc_weight if dir == 'fwd' else 0, ctc_fc_list=[ int(fc) for fc in args.ctc_fc_list.split('_') ] if args.ctc_fc_list is not None and len(args.ctc_fc_list) > 0 else [], input_feeding=args.input_feeding, backward=(dir == 'bwd'), # lm=args.lm_conf, lm=lm, # TODO(hirofumi): load RNNLM in the model init. 
lm_fusion_type=args.lm_fusion_type, contextualize=args.contextualize, lm_init=args.lm_init, lmobj_weight=args.lmobj_weight, share_lm_softmax=args.share_lm_softmax, global_weight=self.main_weight - self.bwd_weight if dir == 'fwd' else self.bwd_weight, mtl_per_batch=args.mtl_per_batch, adaptive_softmax=args.adaptive_softmax) setattr(self, 'dec_' + dir, dec) # sub task for sub in ['sub1', 'sub2']: if getattr(self, sub + '_weight') > 0: if args.dec_type == 'transformer': raise NotImplementedError else: dec_sub = RNNDecoder( eos=self.eos, unk=self.unk, pad=self.pad, blank=self.blank, enc_n_units=self.enc_n_units, attn_type=args.attn_type, attn_dim=args.attn_dim, attn_sharpening_factor=args.attn_sharpening, attn_sigmoid_smoothing=args.attn_sigmoid, attn_conv_out_channels=args.attn_conv_n_channels, attn_conv_kernel_size=args.attn_conv_width, attn_n_heads=1, rnn_type=args.dec_type, n_units=args.dec_n_units, n_projs=args.dec_n_projs, n_layers=args.dec_n_layers, loop_type=args.dec_loop_type, residual=args.dec_residual, bottleneck_dim=args.dec_bottleneck_dim, emb_dim=args.emb_dim, tie_embedding=args.tie_embedding, vocab=getattr(self, 'vocab_' + sub), dropout=args.dropout_dec, dropout_emb=args.dropout_emb, dropout_att=args.dropout_att, ss_prob=args.ss_prob, ss_type=args.ss_type, lsm_prob=args.lsm_prob, fl_weight=args.focal_loss_weight, fl_gamma=args.focal_loss_gamma, ctc_weight=getattr(self, 'ctc_weight_' + sub), ctc_fc_list=[ int(fc) for fc in getattr(args, 'ctc_fc_list_' + sub).split('_') ] if getattr(args, 'ctc_fc_list_' + sub) is not None and len(getattr(args, 'ctc_fc_list_' + sub)) > 0 else [], input_feeding=args.input_feeding, global_weight=getattr(self, sub + '_weight'), mtl_per_batch=args.mtl_per_batch) setattr(self, 'dec_fwd_' + sub, dec_sub) if args.input_type == 'text': if args.vocab == args.vocab_sub1: # Share the embedding layer between input and output self.embed_in = dec.embed else: self.embed_in = Embedding(vocab=args.vocab_sub1, emb_dim=args.emb_dim, dropout=args.dropout_emb, ignore_index=self.pad) # Initialize parameters in CNN layers self.reset_parameters( args.param_init, # dist='xavier_uniform', # dist='kaiming_uniform', dist='lecun', keys=['conv'], ignore_keys=['score']) # Initialize parameters in the encoder if args.enc_type == 'transformer': self.reset_parameters(args.param_init, dist='xavier_uniform', keys=['enc'], ignore_keys=['embed_in']) self.reset_parameters(args.d_model**-0.5, dist='normal', keys=['embed_in']) else: self.reset_parameters(args.param_init, dist=args.param_init_dist, keys=['enc'], ignore_keys=['conv']) # Initialize parameters in the decoder if args.dec_type == 'transformer': self.reset_parameters(args.param_init, dist='xavier_uniform', keys=['dec'], ignore_keys=['embed']) self.reset_parameters(args.d_model**-0.5, dist='normal', keys=['embed']) else: self.reset_parameters(args.param_init, dist=args.param_init_dist, keys=['dec']) # Initialize bias vectors with zero self.reset_parameters(0, dist='constant', keys=['bias']) # Recurrent weights are orthogonalized if args.rec_weight_orthogonal: self.reset_parameters(args.param_init, dist='orthogonal', keys=['rnn', 'weight']) # Initialize bias in forget gate with 1 # self.init_forget_gate_bias_with_one() # Initialize bias in gating with -1 for cold fusion if args.lm_fusion: self.reset_parameters(-1, dist='constant', keys=['linear_lm_gate.fc.bias']) if args.lm_fusion_type == 'deep' and args.lm_fusion: for n, p in self.named_parameters(): if 'output' in n or 'output_bn' in n or 'linear' in n: p.requires_grad = True 
else: p.requires_grad = False
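# --- Illustrative sketch (editor's addition) ---
# The Transformer-specific initialization scheme used above, written as a
# hypothetical helper: xavier-uniform for weight matrices, zeros for biases,
# and a normal distribution for the embedding table, selected by parameter
# name. Using d_model**-0.5 as the std is an assumption based on the
# reset_parameters(args.d_model**-0.5, dist='normal', keys=['embed']) call.
import torch.nn as nn

def init_transformer_params(model, d_model):
    for name, p in model.named_parameters():
        if 'embed' in name and p.dim() == 2:
            nn.init.normal_(p, mean=0.0, std=d_model ** -0.5)
        elif 'bias' in name:
            nn.init.constant_(p, 0.0)
        elif p.dim() >= 2:
            nn.init.xavier_uniform_(p)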