def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the chunk-based Transformer E2E model.

    :param int idim: input feature dimension
    :param int odim: output dimension; index ``odim - 1`` is used as both
        <sos> and <eos>
    :param Namespace args: options (chunk lengths, positional-embedding and
        memory switches, layer sizes, dropout rates, CTC / reporting flags)
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    # Attention dropout defaults to the generic dropout rate; the fallback
    # is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    drop = args.dropout_rate
    attn_drop = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        center_len=args.transformer_encoder_center_chunk_len,
        left_len=args.transformer_encoder_left_chunk_len,
        hop_len=args.transformer_encoder_hop_len,
        right_len=args.transformer_encoder_right_chunk_len,
        abs_pos=args.transformer_encoder_abs_embed,
        rel_pos=args.transformer_encoder_rel_embed,
        use_mem=args.transformer_encoder_use_memory,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        attention_dropout_rate=attn_drop,
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        self_attention_dropout_rate=attn_drop,
        src_attention_dropout_rate=attn_drop,
    )

    # The last vocabulary index doubles as start and end symbol.
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    # reset_parameters (defined elsewhere) runs before the CTC head below
    # exists, so it cannot touch that module.
    self.reset_parameters(args)

    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    with_ctc = args.mtlalpha > 0.0
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if with_ctc
        else None
    )
    # An error calculator is also built whenever CTC is enabled.
    if args.report_cer or args.report_wer or with_ctc:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build a Transformer E2E model whose Encoder/Decoder read ``args`` directly.

    :param int idim: input feature dimension
    :param int odim: output dimension; ``odim - 1`` is both <sos> and <eos>
    :param Namespace args: options namespace, handed as-is to Encoder/Decoder
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        # fall back to the generic dropout rate (stored back into args)
        args.transformer_attn_dropout_rate = args.dropout_rate

    # Encoder and Decoder receive the whole namespace.
    self.encoder = Encoder(idim, args)
    self.decoder = Decoder(odim, args)

    boundary_id = odim - 1
    self.sos = boundary_id
    self.eos = boundary_id
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.reset_parameters(args)
    self.recog_args = None  # unused
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    if args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )
    else:
        self.ctc = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    Builds a Transformer encoder/decoder pair and the label-smoothing
    attention loss, plus an optional CTC head and error calculator.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (``odim - 1`` is <sos>/<eos>)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        # fall back to the generic dropout rate (stored back into args)
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=attn_dropout,
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='asr', arch='transformer')
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    # reset_parameters (defined elsewhere) runs before the CTC head exists.
    self.reset_parameters(args)

    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
    wants_report = args.report_cer or args.report_wer
    if wants_report:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    The attention decoder and its loss are only built when
    ``args.mtlalpha < 1``; otherwise both are ``None``. In "maskctc"
    decoder mode the last index is reserved as the mask token.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    drop = args.dropout_rate
    attn_drop = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        attention_dropout_rate=attn_drop,
    )

    # No attention branch at all when the CTC weight is 1 (or more).
    if args.mtlalpha < 1:
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=drop,
            positional_dropout_rate=drop,
            self_attention_dropout_rate=attn_drop,
            src_attention_dropout_rate=attn_drop,
        )
        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
    else:
        self.decoder = self.criterion = None

    self.blank = 0
    self.decoder_mode = args.decoder_mode
    if self.decoder_mode == "maskctc":
        # last id becomes <mask>, so <sos>/<eos> move down by one
        self.mask_token = odim - 1
        self.sos = self.eos = odim - 2
    else:
        self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.reset_parameters(args)

    self.adim = args.adim  # used for CTC (equal to d_model)
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
    if args.report_cer or args.report_wer:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with two parallel encoders and one decoder.

    ``cn_encoder`` and ``en_encoder`` share identical hyper-parameters; the
    fixed mixing weight ``args.enc_lambda`` is stored as ``self.enc_lambda``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (``odim - 1`` is <sos>/<eos>)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    drop = args.dropout_rate
    attn_drop = args.transformer_attn_dropout_rate

    # One hyper-parameter set, two independently-initialised encoders.
    encoder_kwargs = dict(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        attention_dropout_rate=attn_drop,
    )
    self.cn_encoder = Encoder(**encoder_kwargs)
    self.en_encoder = Encoder(**encoder_kwargs)
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        self_attention_dropout_rate=attn_drop,
        src_attention_dropout_rate=attn_drop,
    )

    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
    if args.report_cer or args.report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None

    # yzl23 config
    self.remove_blank_in_ctc_mode = True
    # Re-initialise only after every submodule exists.
    self.reset_parameters(args)
    self.enc_lambda = args.enc_lambda
    logging.warning("Using fixed encoder lambda: {}".format(self.enc_lambda))
    total_size = self.count_parameters()
    trainable_size = self.count_parameters(requires_grad=True)
    logging.warning(
        "Model total size: {}M, requires_grad size: {}M".format(
            total_size, trainable_size
        )
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with two encoders fused by a learned gate.

    The gate is ``Linear(2 * adim -> lambda_dim)`` followed by a sigmoid,
    where ``lambda_dim`` is ``adim`` when ``args.vectorize_lambda`` is set
    and 1 otherwise. Each encoder gets its own CTC head when
    ``args.mtlalpha > 0``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (``odim - 1`` is <sos>/<eos>)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    drop = args.dropout_rate
    attn_drop = args.transformer_attn_dropout_rate

    encoder_kwargs = dict(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        attention_dropout_rate=attn_drop,
    )
    self.cn_encoder = Encoder(**encoder_kwargs)
    self.en_encoder = Encoder(**encoder_kwargs)

    # gated add module
    # presumably fed the two encoder outputs concatenated (hence 2 * adim) —
    # confirm against forward()
    self.vectorize_lambda = args.vectorize_lambda
    gate_width = args.adim if self.vectorize_lambda else 1
    self.aggregation_module = torch.nn.Sequential(
        torch.nn.Linear(2 * args.adim, gate_width),
        torch.nn.Sigmoid(),
    )
    # NOTE(review): meaning of this constant is not visible here — document it.
    self.language_divider = 1000

    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        self_attention_dropout_rate=attn_drop,
        src_attention_dropout_rate=attn_drop,
    )
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha

    def _make_ctc():
        return CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )

    if args.mtlalpha > 0.0:
        self.cn_ctc = _make_ctc()
        self.en_ctc = _make_ctc()
    else:
        self.cn_ctc = self.en_ctc = None
    # NOTE(review): unlike the sibling constructors, no self.error_calculator
    # is set here — confirm nothing downstream reads it.
    self.rnnlm = None

    # yzl23 config
    self.remove_blank_in_ctc_mode = True
    # Re-initialise only after every submodule exists.
    self.reset_parameters(args)
    total_size = self.count_parameters()
    trainable_size = self.count_parameters(requires_grad=True)
    logging.warning(
        "Model total size: {}M, requires_grad size: {}M".format(
            total_size, trainable_size
        )
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with two gated encoders and optional freezing.

    After all submodules are built and re-initialised, parameters may be
    frozen selectively: ``args.activated_keys`` is a comma-separated list of
    substrings, and only parameters whose name contains at least one of them
    keep ``requires_grad=True``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (``odim - 1`` is <sos>/<eos>)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.cn_encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.en_encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    # gated add module: Linear(2 * adim -> lambda_dim) + Sigmoid, where
    # lambda_dim is adim (per-dimension gate) or 1 (scalar gate)
    self.vectorize_lambda = args.vectorize_lambda
    lambda_dim = args.adim if self.vectorize_lambda else 1
    self.aggregation_module = torch.nn.Sequential(
        torch.nn.Linear(2 * args.adim, lambda_dim), torch.nn.Sigmoid()
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    if args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )
    else:
        self.ctc = None
    if args.report_cer or args.report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
    # yzl23 config
    self.remove_blank_in_ctc_mode = True
    self.reset_parameters(args)  # reset params at the last

    # Freeze everything except parameters matching an activated key.
    # Keys are stripped and empty entries dropped: an empty key (from a
    # trailing or doubled comma such as "decoder,") is a substring of every
    # parameter name and would silently un-freeze the whole model.
    activated_keys = [
        key.strip()
        for key in (args.activated_keys or "").split(",")
        if key.strip()
    ]
    if activated_keys:
        for name, param in self.named_parameters():
            param.requires_grad = any(key in name for key in activated_keys)
    else:
        logging.warning("Not frozen anything.")
    logging.warning(
        "Model total size: {}M, requires_grad size: {}M".format(
            self.count_parameters(), self.count_parameters(requires_grad=True)
        )
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object: Transformer encoder plus target-matching heads.

    No decoder is built. ``poster`` maps ``adim`` to ``odim * oversampling``.
    When ``args.outer`` is set, a ``matcher`` head is added. With
    ``args.residual`` it is ``Linear(odim, odim)`` alongside
    ``matcher_res = Linear(idim, odim)``; without it, it is
    ``Linear(odim + idim, odim)``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (``odim - 1`` is <sos>/<eos>)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the label-smoothing loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    # target matching system organization
    self.oversampling = args.oversampling
    self.residual = args.residual
    self.outer = args.outer
    self.poster = torch.nn.Linear(args.adim, odim * self.oversampling)
    if self.outer and self.residual:
        self.matcher_res = torch.nn.Linear(idim, odim)
        self.matcher = torch.nn.Linear(odim, odim)
    elif self.outer:
        self.matcher = torch.nn.Linear(odim + idim, odim)

    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None
def __init__(self, idim, odim, args):
    """Initialize multi-speaker E2E module.

    :param int idim: input feature dimension (replaced by ``args.n_mels``
        when a frontend is used)
    :param int odim: output dimension; ``odim - 1`` is shared by <sos>/<eos>
    :param Namespace args: training options
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()
    self.num_spkrs = args.num_spkrs
    self.spa = args.spa
    self.pit = PIT(self.num_spkrs)
    # The last index is used as both <sos> and <eos>.
    self.sos = self.eos = odim - 1
    self.subsample = get_subsample(args, mode="asr", arch="rnn_mix")

    # Label-smoothing prior: only when a type is given and the json exists.
    labeldist = None
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )

    # getattr keeps older configs without "use_frontend" working.
    if not getattr(args, "use_frontend", False):
        self.frontend = None
    else:
        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
        idim = args.n_mels

    self.enc = encoder_for(args, idim, self.subsample)
    self.ctc = ctc_for(args, odim, reduce=False)
    # one attention module per speaker when args.spa is set, otherwise one
    self.att = att_for(args, self.num_spkrs if args.spa else 1)
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)
    self.init_like_chainer()

    # Decoding options, only needed when CER/WER reporting is enabled.
    if "report_cer" in vars(args) and (args.report_cer or args.report_wer):
        self.recog_args = argparse.Namespace(
            beam_size=args.beam_size,
            penalty=args.penalty,
            ctc_weight=args.ctc_weight,
            maxlenratio=args.maxlenratio,
            minlenratio=args.minlenratio,
            lm_weight=args.lm_weight,
            rnnlm=args.rnnlm,
            nbest=args.nbest,
            space=args.sym_space,
            blank=args.sym_blank,
        )
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False
    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, num_time_mask=2, num_freq_mask=2, freq_mask_length=15,
             time_mask_length=15, feature_dim=320, model_size=512,
             feed_forward_size=1024, hidden_size=64, dropout=0.1, num_head=8,
             num_encoder_layer=6, num_decoder_layer=6,
             vocab_path='testing_vocab.model', max_feature_length=1024,
             max_token_length=50, enable_spec_augment=True, share_weight=True,
             smoothing=0.1, restrict_left_length=20, restrict_right_length=20,
             mtlalpha=0.2, report_wer=True):
    """Build the Transformer with a token head and an auxiliary 4-way switch head.

    :param int num_time_mask: number of SpecAugment time masks
    :param int num_freq_mask: number of SpecAugment frequency masks
    :param int freq_mask_length: SpecAugment frequency-mask length
    :param int time_mask_length: SpecAugment time-mask length
    :param int feature_dim: encoder input feature dimension
    :param int model_size: attention dimension (also the CTC projection size)
    :param int feed_forward_size: feed-forward units per block
    :param int hidden_size: accepted but not used by this constructor
    :param float dropout: dropout rate (decoder source-attention dropout is 0)
    :param int num_head: number of attention heads
    :param int num_encoder_layer: number of encoder blocks
    :param int num_decoder_layer: number of decoder blocks
    :param str vocab_path: path of the vocabulary model loaded by ``Vocab``;
        with ``report_wer`` the token list is read from the sibling file
        whose trailing ``.model`` is replaced by ``.vocab``
    :param int max_feature_length: maximum sequence length for SpecAugment
    :param int max_token_length: stored as ``self.max_token_length``
    :param bool enable_spec_augment: build a SpecAugment module
    :param bool share_weight: accepted but not used by this constructor
    :param float smoothing: label smoothing of the token loss
    :param int restrict_left_length: stored as ``self.restrict_left_length``
    :param int restrict_right_length: stored as ``self.restrict_right_length``
    :param float mtlalpha: CTC weight; a CTC head is built when > 0
    :param bool report_wer: build an ErrorCalculator for WER reporting
    """
    import logging

    super(Transformer, self).__init__()
    self.enable_spec_augment = enable_spec_augment
    self.max_token_length = max_token_length
    self.restrict_left_length = restrict_left_length
    self.restrict_right_length = restrict_right_length
    self.vocab = Vocab(vocab_path)
    self.sos = self.vocab.bos_id
    self.eos = self.vocab.eos_id
    self.adim = model_size
    self.odim = self.vocab.vocab_size
    self.ignore_id = self.vocab.pad_id
    if enable_spec_augment:
        self.spec_augment = SpecAugment(
            num_time_mask=num_time_mask,
            num_freq_mask=num_freq_mask,
            freq_mask_length=freq_mask_length,
            time_mask_length=time_mask_length,
            max_sequence_length=max_feature_length,
        )
    self.encoder = Encoder(
        idim=feature_dim,
        attention_dim=model_size,
        attention_heads=num_head,
        linear_units=feed_forward_size,
        num_blocks=num_encoder_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=dropout,
        input_layer='linear',
        padding_idx=self.vocab.pad_id,
    )
    # The decoder's built-in output layer is disabled; the two linear heads
    # below presumably consume its hidden states in forward() — confirm.
    self.decoder = Decoder(
        odim=self.vocab.vocab_size,
        attention_dim=model_size,
        attention_heads=num_head,
        linear_units=feed_forward_size,
        num_blocks=num_decoder_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        self_attention_dropout_rate=dropout,
        src_attention_dropout_rate=0,
        input_layer='embed',
        use_output_layer=False,
    )
    self.decoder_linear = t.nn.Linear(model_size, self.vocab.vocab_size, bias=True)
    self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)
    self.criterion = LabelSmoothingLoss(
        size=self.odim,
        smoothing=smoothing,
        padding_idx=self.vocab.pad_id,
        normalize_length=True,
    )
    self.switch_criterion = LabelSmoothingLoss(
        size=4, smoothing=0, padding_idx=self.vocab.pad_id, normalize_length=True
    )
    self.mtlalpha = mtlalpha
    if mtlalpha > 0.0:
        self.ctc = CTC(
            self.odim,
            eprojs=self.adim,
            dropout_rate=dropout,
            ctc_type='builtin',
            reduce=False,
        )
    else:
        self.ctc = None
    if report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        # Swap only the trailing ".model" suffix: a blanket str.replace would
        # also rewrite ".model" appearing in directory names.
        if vocab_path.endswith('.model'):
            token_list_path = vocab_path[:-len('.model')] + '.vocab'
        else:
            # non-standard name: keep the previous lookup behaviour
            token_list_path = vocab_path.replace('.model', '.vocab')
        # Keep the first tab-separated column of each line. Read as UTF-8
        # instead of the locale default so non-ASCII tokens decode reliably.
        with open(token_list_path, encoding='utf-8') as reader:
            self.char_list = [line.split('\t')[0] for line in reader]
        self.error_calculator = ErrorCalculator(
            char_list=self.char_list,
            sym_space=' ',
            sym_blank=self.vocab.blank_token,
            report_wer=True,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
    self.reporter = Reporter()
    # NOTE(review): duplicates self.switch_criterion but pads with index 0
    # instead of the vocab pad id — confirm which one forward() uses.
    self.switch_loss = LabelSmoothingLoss(size=4, smoothing=0, padding_idx=0)
    logging.debug('initializing parameters (xavier_normal)')
    initialize(self, init_type='xavier_normal')
    logging.debug('parameters initialized')
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with configurable attention types.

    Encoder and decoder additionally receive an attention type and span
    options (``max_attn_span``, ``span_init``, ``span_ratio``,
    ``ratio_adaptive``). All of these are read with ``getattr`` so older
    configs that lack them still load.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs (``odim - 1`` is <sos>/<eos>)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id ignored by the attention loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    drop = args.dropout_rate
    attn_drop = args.transformer_attn_dropout_rate
    span_init = getattr(args, 'span_init', None)
    span_ratio = getattr(args, 'span_ratio', None)
    ratio_adaptive = getattr(args, 'ratio_adaptive', None)

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        attention_dropout_rate=attn_drop,
        attention_type=getattr(args, 'transformer_enc_attn_type', 'self_attn'),
        max_attn_span=getattr(args, 'enc_max_attn_span', [None]),
        span_init=span_init,
        span_ratio=span_ratio,
        ratio_adaptive=ratio_adaptive,
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=drop,
        positional_dropout_rate=drop,
        self_attention_dropout_rate=attn_drop,
        src_attention_dropout_rate=attn_drop,
        attention_type=getattr(args, 'transformer_dec_attn_type', 'self_attn'),
        max_attn_span=getattr(args, 'dec_max_attn_span', [None]),
        span_init=span_init,
        span_ratio=span_ratio,
        ratio_adaptive=ratio_adaptive,
    )

    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
    if args.report_cer or args.report_wer:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None

    # Remember the attention/span configuration for later use.
    self.attention_enc_type = getattr(args, 'transformer_enc_attn_type', 'self_attn')
    self.attention_dec_type = getattr(args, 'transformer_dec_attn_type', 'self_attn')
    self.span_loss_coef = getattr(args, 'span_loss_coef', None)
    self.ratio_adaptive = getattr(args, 'ratio_adaptive', None)
    self.sym_blank = args.sym_blank