def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct the model from an option namespace.

    :param int idim: dimension of inputs (passed to ``Encoder``)
    :param int odim: dimension of outputs (passed to ``Decoder``, the loss and CTC)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    self.encoder = Encoder(idim, args)
    self.decoder = Decoder(odim, args)

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset before the CTC head is created, as in the original.
    self.reset_parameters(args)
    self.recog_args = None  # unused
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha

    # A CTC head only exists when the CTC weight is positive.
    use_ctc = args.mtlalpha > 0.0
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if use_ctc
        else None
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id forwarded to the parent constructor
    """
    # Forward the caller's ignore_id. It was previously pinned to -1 here,
    # which silently discarded any non-default value.
    super(E2E, self).__init__(idim, odim, args, ignore_id=ignore_id)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    # NOTE(review): assigned after the parent constructor has run, so this
    # presumably replaces an encoder built there -- confirm against the parent.
    self.encoder = EncoderMix(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks_sd=args.elayers_sd,
        num_blocks_rec=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        num_spkrs=args.num_spkrs,
    )

    # The CTC head is built with reduce=False and only when the CTC weight
    # is positive.
    if args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=False
        )
    else:
        self.ctc = None

    self.num_spkrs = args.num_spkrs
    self.pit = PIT(self.num_spkrs)
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct the model with a chunk-configured encoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    encoder_options = dict(
        idim=idim,
        center_len=args.transformer_encoder_center_chunk_len,
        left_len=args.transformer_encoder_left_chunk_len,
        hop_len=args.transformer_encoder_hop_len,
        right_len=args.transformer_encoder_right_chunk_len,
        abs_pos=args.transformer_encoder_abs_embed,
        rel_pos=args.transformer_encoder_rel_embed,
        use_mem=args.transformer_encoder_use_memory,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=attn_dropout,
    )
    self.encoder = Encoder(**encoder_options)

    decoder_options = dict(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )
    self.decoder = Decoder(**decoder_options)

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset before the CTC head is created, as in the original.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # An error calculator is built when CER/WER reporting is requested or
    # when the CTC weight is positive.
    if not (args.report_cer or args.report_wer or args.mtlalpha > 0.0):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate
    dropout = args.dropout_rate

    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=attn_dropout,
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='asr', arch='transformer')
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset before the CTC head is created, as in the original.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, dropout, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # The error calculator is only built when CER or WER reporting is requested.
    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    Besides the encoder/decoder pair, this stores the decoder window
    settings ``args.dec_left_window`` and ``args.dec_right_window``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    encoder_options = {
        "idim": idim,
        "attention_dim": args.adim,
        "attention_heads": args.aheads,
        "linear_units": args.eunits,
        "num_blocks": args.elayers,
        "input_layer": args.transformer_input_layer,
        "dropout_rate": args.dropout_rate,
        "positional_dropout_rate": args.dropout_rate,
        "attention_dropout_rate": attn_dropout,
    }
    self.encoder = Encoder(**encoder_options)

    decoder_options = {
        "odim": odim,
        "attention_dim": args.adim,
        "attention_heads": args.aheads,
        "linear_units": args.dunits,
        "num_blocks": args.dlayers,
        "input_layer": args.transformer_output_layer,
        "dropout_rate": args.dropout_rate,
        "positional_dropout_rate": args.dropout_rate,
        "self_attention_dropout_rate": attn_dropout,
        "src_attention_dropout_rate": attn_dropout,
    }
    self.decoder = Decoder(**decoder_options)

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset before the CTC head is created, as in the original.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
    self.rnnlm = None

    self.left_window = args.dec_left_window
    self.right_window = args.dec_right_window
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
    """Construct an E2E object for transducer model.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options
        ignore_id (int): label id stored on the model and handed to the loss
        blank_id (int): accepted for interface compatibility; not read here

    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        d_model=args.adim,
        n_heads=args.aheads,
        d_ffn=args.eunits,
        layers=args.elayers,
        kernel_size=args.kernel_size,
        input_layer=args.input_layer,
        dropout_rate=args.dropout_rate,
        causal=args.causal,
    )
    # ``args`` is mutated here: eprojs is set to the attention dimension.
    args.eprojs = args.adim

    # NOTE(review): only the decoder construction is taken to be guarded by
    # this condition; the assignments below are treated as unconditional.
    # Confirm against the properly indented source.
    needs_decoder = args.mtlalpha < 1.0
    if needs_decoder:
        self.decoder = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
        )

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.ignore_id = ignore_id
    self.odim = odim
    self.adim = args.adim
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset before the CTC head is created, as in the original.
    self.reset_parameters(args)
    # Second assignment of adim kept exactly as in the original sequence.
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with two encoders and a gated aggregation module.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # Both encoders are built from the same options.
    encoder_options = dict(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=attn_dropout,
    )
    self.cn_encoder = Encoder(**encoder_options)
    self.en_encoder = Encoder(**encoder_options)

    # gated add module: Linear(2 * adim -> lambda_dim) followed by a sigmoid;
    # lambda_dim is adim when vectorize_lambda is set, otherwise 1.
    self.vectorize_lambda = args.vectorize_lambda
    if self.vectorize_lambda:
        lambda_dim = args.adim
    else:
        lambda_dim = 1
    gate_projection = torch.nn.Linear(2 * args.adim, lambda_dim)
    self.aggregation_module = torch.nn.Sequential(gate_projection, torch.nn.Sigmoid())
    self.language_divider = 1000

    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha

    def build_ctc():
        # One CTC head per encoder, both created with identical settings.
        return CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )

    if args.mtlalpha > 0.0:
        self.cn_ctc = build_ctc()
        self.en_ctc = build_ctc()
    else:
        self.cn_ctc = None
        self.en_ctc = None
    self.rnnlm = None

    # yzl23 config
    self.remove_blank_in_ctc_mode = True

    # reset params at the last
    self.reset_parameters(args)

    total_size = self.count_parameters()
    trainable_size = self.count_parameters(requires_grad=True)
    logging.warning(
        "Model total size: {}M, requires_grad size: {}M".format(
            total_size, trainable_size
        )
    )
def __init__(self, idim, odim, args):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    # Both initializer calls are kept from the original: the cooperative
    # super() call and the explicit torch.nn.Module initialization.
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    self.asr_weight = getattr(args, "asr_weight", 0)
    self.mt_weight = getattr(args, "mt_weight", 0)
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
    assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    # (``args`` is mutated: char_list is always present afterwards)
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()

    # below means the last number becomes eos/sos ID
    # note that sos/eos IDs are identical
    self.sos = odim - 1
    self.eos = odim - 1
    self.pad = 0
    # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class
    # in ASR. However, blank labels are not used in NMT. To keep the vocabulary size,
    # we use index:0 for padding instead of adding one more class.

    # subsample info
    # +1 means input (+1) and layers outputs (args.elayer)
    # The builtin ``int`` replaces the ``np.int`` alias, which was removed
    # from NumPy and raises AttributeError on current releases.
    subsample = np.ones(args.elayers + 1, dtype=int)
    if args.etype.endswith("p") and not args.etype.startswith("vgg"):
        ss = args.subsample.split("_")
        for j in range(min(args.elayers + 1, len(ss))):
            subsample[j] = int(ss[j])
    else:
        logging.warning(
            'Subsampling is not performed for vgg*. It is performed in max pooling layers at CNN.'
        )
    logging.info('subsample: ' + ' '.join([str(x) for x in subsample]))
    self.subsample = subsample

    # label smoothing info
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json)
    else:
        labeldist = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.joint_asr = getattr(args, "joint_asr", False)
    self.replace_sos = getattr(args, "replace_sos", False)

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # attention (ST)
    self.att = att_for(args)
    # decoder (ST)
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

    # submodule for ASR task
    self.ctc = None
    self.att_asr = None
    self.dec_asr = None
    if self.asr_weight > 0:
        if self.mtlalpha > 0.0:
            self.ctc = CTC(
                odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
            )
        if self.mtlalpha < 1.0:
            # attention (asr)
            self.att_asr = att_for(args)
            # decoder (asr): built from a deep copy of args so that forcing
            # the attention type does not leak into the caller's namespace
            args_asr = copy.deepcopy(args)
            args_asr.atype = 'location'  # TODO(hirofumi0810): make this option
            self.dec_asr = decoder_for(
                args_asr, odim, self.sos, self.eos, self.att_asr, labeldist
            )

    # submodule for MT task
    if self.mt_weight > 0:
        self.embed_mt = torch.nn.Embedding(odim, args.eunits, padding_idx=self.pad)
        self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
        self.enc_mt = encoder_for(
            args, args.eunits, subsample=np.ones(args.elayers + 1, dtype=int)
        )

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    # The parentheses matter: ``and`` binds tighter than ``or``, so without
    # them report_wer alone enabled ASR error reporting even when
    # asr_weight == 0, i.e. when no ASR submodule was built above.
    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        recog_args = {
            'beam_size': args.beam_size,
            'penalty': args.penalty,
            'ctc_weight': args.ctc_weight,
            'maxlenratio': args.maxlenratio,
            'minlenratio': args.minlenratio,
            'lm_weight': args.lm_weight,
            'rnnlm': args.rnnlm,
            'nbest': args.nbest,
            'space': args.sym_space,
            'blank': args.sym_blank,
            'tgt_lang': False,
        }
        self.recog_args = argparse.Namespace(**recog_args)
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    if args.report_bleu:
        # Same options as for recognition, except the CTC weight is forced to 0.
        trans_args = {
            'beam_size': args.beam_size,
            'penalty': args.penalty,
            'ctc_weight': 0,
            'maxlenratio': args.maxlenratio,
            'minlenratio': args.minlenratio,
            'lm_weight': args.lm_weight,
            'rnnlm': args.rnnlm,
            'nbest': args.nbest,
            'space': args.sym_space,
            'blank': args.sym_blank,
            'tgt_lang': False,
        }
        self.trans_args = argparse.Namespace(**trans_args)
        self.report_bleu = args.report_bleu
    else:
        self.report_bleu = False

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with a dual decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=attn_dropout,
    )

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    self.do_asr = self.asr_weight > 0 and args.mtlalpha < 1

    # Optional settings: each is copied from ``args`` under the same name,
    # falling back to the listed default when the option is absent.
    optional_settings = (
        # cross-attention parameters
        ("cross_weight", 0.0),
        ("cross_self", False),
        ("cross_src", False),
        ("cross_operator", None),
        ("cross_to_asr", False),
        ("cross_to_st", False),
        ("num_decoders", 1),
        ("wait_k_asr", 0),
        ("wait_k_st", 0),
        ("cross_src_from", "embedding"),
        ("cross_self_from", "embedding"),
        ("cross_weight_learnable", False),
        # one-to-many ST experiments
        ("one_to_many", False),
        ("langs_dict", None),
        ("lang_tok", None),
        ("normalize_before", True),
    )
    for option_name, fallback in optional_settings:
        setattr(self, option_name, getattr(args, option_name, fallback))
    logging.info(f'self.normalize_before = {self.normalize_before}')

    # Check parameters
    # NOTE(review): the nesting of these checks was reconstructed from a
    # whitespace-collapsed source; confirm against the indented original.
    crossing = self.cross_to_asr or self.cross_to_st
    if self.cross_operator == 'sum' and self.cross_weight <= 0:
        assert (not self.cross_to_asr) and (not self.cross_to_st)
    if crossing:
        assert self.do_asr
        assert self.cross_self or self.cross_src
    assert bool(self.cross_operator) == (self.do_asr and (self.cross_to_asr or self.cross_to_st))
    if self.cross_src_from != "embedding" or self.cross_self_from != "embedding":
        assert self.normalize_before
    if self.wait_k_asr > 0:
        assert self.wait_k_st == 0
    elif self.wait_k_st > 0:
        assert self.wait_k_asr == 0
    else:
        assert self.wait_k_asr == 0
        assert self.wait_k_st == 0

    logging.info("*** Cross attention parameters ***")
    flag_messages = (
        (self.cross_to_asr, "| Cross to ASR"),
        (self.cross_to_st, "| Cross to ST"),
        (self.cross_self, "| Cross at Self"),
        (self.cross_src, "| Cross at Source"),
    )
    for enabled, message in flag_messages:
        if enabled:
            logging.info(message)
    if crossing:
        logging.info(f'| Cross operator: {self.cross_operator}')
        logging.info(f'| Cross sum weight: {self.cross_weight}')
        if self.cross_src:
            logging.info(f'| Cross source from: {self.cross_src_from}')
        if self.cross_self:
            logging.info(f'| Cross self from: {self.cross_self_from}')
    logging.info(f'| wait_k_asr = {self.wait_k_asr}')
    logging.info(f'| wait_k_st = {self.wait_k_st}')

    post_norm = not self.normalize_before
    if (self.cross_src_from != "embedding" and self.cross_src) and post_norm:
        logging.warning(f'WARNING: Resort to using self.cross_src_from == embedding for cross at source attention.')
    if (self.cross_self_from != "embedding" and self.cross_self) and post_norm:
        logging.warning(f'WARNING: Resort to using self.cross_self_from == embedding for cross at self attention.')

    self.dual_decoder = DualDecoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        normalize_before=self.normalize_before,
        cross_operator=self.cross_operator,
        cross_weight_learnable=self.cross_weight_learnable,
        cross_weight=self.cross_weight,
        cross_self=self.cross_self,
        cross_src=self.cross_src,
        cross_to_asr=self.cross_to_asr,
        cross_to_st=self.cross_to_st,
    )

    self.pad = 0
    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.idim = idim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim

    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=attn_dropout,
            padding_idx=0,
        )

    # place after the submodule initialization
    # (the CTC head below is still created after this call, as in the original)
    self.reset_parameters(args)
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # ASR error reporting needs both an ASR weight and a requested metric.
    if not (self.asr_weight > 0 and (args.report_cer or args.report_wer)):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    if self.multilingual:
        assert self.replace_sos

    # NOTE(review): treated as independent of the multilingual check above.
    if self.lang_tok == "encoder-pre-sum":
        self.language_embeddings = build_embedding(
            self.langs_dict, self.idim, padding_idx=self.pad
        )
        print(f'language_embeddings: {self.language_embeddings}')
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with two encoders and a fixed encoder lambda.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # Both encoders are built from the same options.
    encoder_options = dict(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=attn_dropout,
    )
    self.cn_encoder = Encoder(**encoder_options)
    self.en_encoder = Encoder(**encoder_options)

    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # The error calculator is only built when CER or WER reporting is requested.
    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None

    # yzl23 config
    self.remove_blank_in_ctc_mode = True

    # reset params at the last
    self.reset_parameters(args)

    self.enc_lambda = args.enc_lambda
    logging.warning("Using fixed encoder lambda: {}".format(self.enc_lambda))

    total_size = self.count_parameters()
    trainable_size = self.count_parameters(requires_grad=True)
    logging.warning(
        "Model total size: {}M, requires_grad size: {}M".format(
            total_size, trainable_size
        )
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with linear "poster" and "matcher" layers.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    encoder_options = dict(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.encoder = Encoder(**encoder_options)

    # target matching system organization
    self.oversampling = args.oversampling
    self.residual = args.residual
    self.outer = args.outer
    self.poster = torch.nn.Linear(args.adim, odim * self.oversampling)
    # NOTE(review): the ``else`` is read as belonging to ``if self.residual``,
    # so no matcher is created when ``outer`` is false. Confirm against the
    # properly indented source.
    if self.outer:
        if self.residual:
            self.matcher_res = torch.nn.Linear(idim, odim)
            self.matcher = torch.nn.Linear(odim, odim)
        else:
            self.matcher = torch.nn.Linear(odim + idim, odim)

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset before the CTC head is created, as in the original.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # The error calculator is only built when CER or WER reporting is requested.
    self.error_calculator = (
        ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        if args.report_cer or args.report_wer
        else None
    )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with an additional language-id multitask head.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the losses
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate
    attention_dim = args.adim

    self.encoder = Encoder(
        idim=idim,
        attention_dim=attention_dim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=attn_dropout,
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=attention_dim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = attention_dim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, attention_dim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # The error calculator is only built when CER or WER reporting is requested.
    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None

    # yzl23 config
    self.remove_blank_in_ctc_mode = True

    # lid multitask related
    self.lid_odim = 2  # cn and en
    # src attention
    self.lid_src_att = MultiHeadedAttention(args.aheads, attention_dim, attn_dropout)
    self.lid_output_layer = torch.nn.Linear(attention_dim, self.lid_odim)
    # here we hack to use lsm loss, but with lsm_weight ZERO
    self.lid_criterion = LanguageIDMultitakLoss(
        self.ignore_id,
        normalize_length=args.transformer_length_normalized_loss,
    )
    self.lid_mtl_alpha = args.lid_mtl_alpha
    logging.warning("language id multitask training alpha %f" % (self.lid_mtl_alpha))
    self.log_lid_mtl_acc = args.log_lid_mtl_acc

    # reset parameters (done last, after every submodule exists)
    self.reset_parameters(args)
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with a multimodal encoder and decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored on the model and handed to the loss
    :raises ValueError: if ``args.encoder_type`` is not one of
        ``'all_add'``, ``'proportion_add'`` or ``'vat'``
    """
    torch.nn.Module.__init__(self)

    # No attention-specific dropout given: reuse the general dropout rate.
    # The resolved value is written back into ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    self.encoder_type = getattr(args, 'encoder_type', 'all_add')
    self.vbs = getattr(args, 'vbs', False)
    self.noise = getattr(args, 'noise_type', 'none')

    # The encoder implementation is picked by name and imported lazily.
    if self.encoder_type == 'all_add':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_all_add import MultimodalEncoder
    elif self.encoder_type == 'proportion_add':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_proportion_add import MultimodalEncoder
    elif self.encoder_type == 'vat':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_vat import MultimodalEncoder
    else:
        # Without this branch an unknown type surfaced as an opaque
        # UnboundLocalError on ``MultimodalEncoder`` a few lines below.
        raise ValueError(
            "unsupported encoder_type {!r}; expected one of "
            "'all_add', 'proportion_add', 'vat'".format(self.encoder_type)
        )

    self.encoder = MultimodalEncoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        visual_dim=args.visual_dim,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        vbs=self.vbs,
    )
    self.decoder = MultimodalDecoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    self.pad = 0
    # sos and eos share the last output index
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    if self.asr_weight > 0 and args.mtlalpha < 1:
        self.decoder_asr = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )

    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0,
        )

    # place after the submodule initialization
    # (the CTC head below is still created after this call, as in the original)
    self.reset_parameters(args)
    if args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )
    else:
        self.ctc = None

    # ASR error reporting needs both an ASR weight and a requested metric.
    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        from espnet.nets.e2e_asr_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    if self.multilingual:
        assert self.replace_sos

    # Device string chosen once at construction time from CUDA availability.
    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
def __init__(self, num_time_mask=2, num_freq_mask=2, freq_mask_length=15,
             time_mask_length=15, feature_dim=320, model_size=512,
             feed_forward_size=1024, hidden_size=64, dropout=0.1, num_head=8,
             num_encoder_layer=6, num_decoder_layer=6,
             vocab_path='testing_vocab.model', max_feature_length=1024,
             max_token_length=50, enable_spec_augment=True, share_weight=True,
             smoothing=0.1, restrict_left_length=20, restrict_right_length=20,
             mtlalpha=0.2, report_wer=True):
    """Build the transformer model together with its vocabulary and losses.

    The vocabulary is loaded from ``vocab_path``; its ids provide the
    sos/eos/padding indices and the output dimension. ``hidden_size`` and
    ``share_weight`` are accepted but not used in this constructor.
    """
    super(Transformer, self).__init__()

    # plain configuration values
    self.enable_spec_augment = enable_spec_augment
    self.max_token_length = max_token_length
    self.restrict_left_length = restrict_left_length
    self.restrict_right_length = restrict_right_length

    # vocabulary-derived settings
    vocab = Vocab(vocab_path)
    self.vocab = vocab
    self.sos = vocab.bos_id
    self.eos = vocab.eos_id
    self.adim = model_size
    self.odim = vocab.vocab_size
    self.ignore_id = vocab.pad_id

    # the augmentation module exists only when it is enabled
    if enable_spec_augment:
        self.spec_augment = SpecAugment(
            num_time_mask=num_time_mask,
            num_freq_mask=num_freq_mask,
            freq_mask_length=freq_mask_length,
            time_mask_length=time_mask_length,
            max_sequence_length=max_feature_length,
        )

    self.encoder = Encoder(
        idim=feature_dim,
        attention_dim=model_size,
        attention_heads=num_head,
        linear_units=feed_forward_size,
        num_blocks=num_encoder_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=dropout,
        input_layer='linear',
        padding_idx=vocab.pad_id,
    )
    # the decoder is built without its own output layer; the two linear
    # heads below are created separately
    self.decoder = Decoder(
        odim=vocab.vocab_size,
        attention_dim=model_size,
        attention_heads=num_head,
        linear_units=feed_forward_size,
        num_blocks=num_decoder_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        self_attention_dropout_rate=dropout,
        src_attention_dropout_rate=0,
        input_layer='embed',
        use_output_layer=False,
    )
    self.decoder_linear = t.nn.Linear(model_size, vocab.vocab_size,
                                      bias=True)
    self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)

    # losses for the token head and the 4-way switch head
    self.criterion = LabelSmoothingLoss(
        size=self.odim,
        smoothing=smoothing,
        padding_idx=vocab.pad_id,
        normalize_length=True,
    )
    self.switch_criterion = LabelSmoothingLoss(
        size=4,
        smoothing=0,
        padding_idx=vocab.pad_id,
        normalize_length=True,
    )

    self.mtlalpha = mtlalpha
    self.ctc = (
        CTC(self.odim, eprojs=self.adim, dropout_rate=dropout,
            ctc_type='builtin', reduce=False)
        if mtlalpha > 0.0 else None
    )

    if not report_wer:
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        # the token list lives next to the model file: first
        # tab-separated column of each line of the ".vocab" file
        token_path = vocab_path.replace('.model', '.vocab')
        with open(token_path) as reader:
            raw_lines = reader.readlines()
        self.char_list = [entry.split('\t')[0] for entry in raw_lines]
        self.error_calculator = ErrorCalculator(
            char_list=self.char_list,
            sym_space=' ',
            sym_blank=vocab.blank_token,
            report_wer=True,
        )

    self.rnnlm = None
    self.reporter = Reporter()
    self.switch_loss = LabelSmoothingLoss(size=4, smoothing=0,
                                          padding_idx=0)

    # parameter initialisation runs last, once every submodule exists
    print('initing')
    initialize(self, init_type='xavier_normal')
    print('inited')
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and handed to
        the label-smoothing loss
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    # options shared by the ST decoder and the optional ASR decoder
    decoder_common = dict(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.decoder = Decoder(
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        **decoder_common
    )

    self.pad = 0  # use <blank> for padding
    last_index = odim - 1
    self.sos = last_index
    self.eos = last_index
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="st", arch="transformer")
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    asr_enabled = self.asr_weight > 0
    if asr_enabled and args.mtlalpha < 1:
        self.decoder_asr = Decoder(**decoder_common)

    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer="embed",
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0,
        )

    # NOTE: place after the submodule initialization
    self.reset_parameters(args)

    self.adim = args.adim  # used for CTC (equal to d_model)
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if asr_enabled and args.mtlalpha > 0.0 else None
    )

    # translation error calculator
    self.error_calculator = MTErrorCalculator(
        args.char_list, args.sym_space, args.sym_blank, args.report_bleu)

    # recognition error calculator
    self.error_calculator_asr = ASRErrorCalculator(
        args.char_list,
        args.sym_space,
        args.sym_blank,
        args.report_cer,
        args.report_wer,
    )
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
def __init__(self, idim, odim, args):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :raises AssertionError: if ``asr_weight`` or ``mt_weight`` is outside
        ``[0.0, 1.0)`` or ``mtlalpha`` is outside ``[0.0, 1.0]``
    """
    # NOTE(review): both initialisers are kept from the original; whether
    # the second call is redundant depends on the class bases, which are
    # not visible here.
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    self.asr_weight = args.asr_weight
    self.mt_weight = args.mt_weight
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
    assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()

    # below means the last number becomes eos/sos ID
    # note that sos/eos IDs are identical
    self.sos = odim - 1
    self.eos = odim - 1
    self.pad = 0
    # NOTE: we reserve index:0 for <pad> although this is reserved for a
    # blank class in ASR. However, blank labels are not used in MT.
    # To keep the vocabulary size,
    # we use index:0 for padding instead of adding one more class.

    # subsample info
    self.subsample = get_subsample(args, mode="st", arch="rnn")

    # label smoothing info
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json)
    else:
        labeldist = None

    # multilingual related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # attention (ST)
    self.att = att_for(args)
    # decoder (ST)
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att,
                           labeldist)

    # submodule for ASR task
    self.ctc = None
    self.att_asr = None
    self.dec_asr = None
    if self.asr_weight > 0:
        if self.mtlalpha > 0.0:
            self.ctc = CTC(
                odim,
                args.eprojs,
                args.dropout_rate,
                ctc_type=args.ctc_type,
                reduce=True,
            )
        if self.mtlalpha < 1.0:
            # attention (asr)
            self.att_asr = att_for(args)
            # decoder (asr); a copy of args is used so that forcing the
            # attention type does not affect the caller's namespace
            args_asr = copy.deepcopy(args)
            args_asr.atype = "location"  # TODO(hirofumi0810): make this option
            self.dec_asr = decoder_for(
                args_asr, odim, self.sos, self.eos, self.att_asr, labeldist)

    # submodule for MT task
    if self.mt_weight > 0:
        self.embed_mt = torch.nn.Embedding(
            odim, args.eunits, padding_idx=self.pad)
        self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
        # all-ones subsample factors for the text encoder. The builtin
        # ``int`` is used because the ``np.int`` alias was removed in
        # NumPy 1.24; it denotes the same dtype.
        self.enc_mt = encoder_for(
            args,
            args.eunits,
            subsample=np.ones(args.elayers + 1, dtype=int),
        )

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    # The parentheses matter: without them ``and`` binds tighter than
    # ``or`` and WER reporting would be enabled even when the ASR task is
    # switched off (asr_weight == 0).
    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        recog_args = {
            "beam_size": args.beam_size,
            "penalty": args.penalty,
            "ctc_weight": args.ctc_weight,
            "maxlenratio": args.maxlenratio,
            "minlenratio": args.minlenratio,
            "lm_weight": args.lm_weight,
            "rnnlm": args.rnnlm,
            "nbest": args.nbest,
            "space": args.sym_space,
            "blank": args.sym_blank,
            "tgt_lang": False,
        }
        self.recog_args = argparse.Namespace(**recog_args)
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    if args.report_bleu:
        # same search options as above, but CTC scoring is disabled
        trans_args = {
            "beam_size": args.beam_size,
            "penalty": args.penalty,
            "ctc_weight": 0,
            "maxlenratio": args.maxlenratio,
            "minlenratio": args.minlenratio,
            "lm_weight": args.lm_weight,
            "rnnlm": args.rnnlm,
            "nbest": args.nbest,
            "space": args.sym_space,
            "blank": args.sym_blank,
            "tgt_lang": False,
        }
        self.trans_args = argparse.Namespace(**trans_args)
        self.report_bleu = args.report_bleu
    else:
        self.report_bleu = False

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and handed to
        the label-smoothing loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    # both encoders are built from exactly the same options
    encoder_options = dict(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.cn_encoder = Encoder(**encoder_options)
    self.en_encoder = Encoder(**encoder_options)

    # gated add module: linear layer over 2 * adim features followed by a
    # sigmoid, producing either adim values or a single value
    self.vectorize_lambda = args.vectorize_lambda
    if self.vectorize_lambda:
        lambda_dim = args.adim
    else:
        lambda_dim = 1
    gate_linear = torch.nn.Linear(2 * args.adim, lambda_dim)
    self.aggregation_module = torch.nn.Sequential(
        gate_linear, torch.nn.Sigmoid())

    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    # sos and eos share the last output index
    last_index = odim - 1
    self.sos = last_index
    self.eos = last_index
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()

    # self.lsm_weight = a
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    # self.verbose = args.verbose
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0 else None
    )

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list, args.sym_space, args.sym_blank,
            args.report_cer, args.report_wer)
    self.rnnlm = None

    # yzl23 config
    self.remove_blank_in_ctc_mode = True

    # reset params at the last
    self.reset_parameters(args)

    # Parameter freezing: when activated_keys is given, only parameters
    # whose name contains one of the comma-separated keys stay trainable;
    # all others are frozen.
    if not args.activated_keys:
        logging.warning("Not frozen anything.")
    else:
        trainable_keys = args.activated_keys.split(',')
        for param_name, param in self.named_parameters():
            param.requires_grad = any(
                key in param_name for key in trainable_keys)

    logging.warning(
        "Model total size: {}M, requires_grad size: {}M".format(
            self.count_parameters(),
            self.count_parameters(requires_grad=True)))
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and handed to
        the label-smoothing loss when an attention decoder is built
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    # The attention decoder and its loss are only built when the
    # attention branch carries weight (mtlalpha < 1).
    use_attention_decoder = args.mtlalpha < 1
    if use_attention_decoder:
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
    else:
        self.decoder = None
        self.criterion = None

    self.blank = 0
    self.decoder_mode = args.decoder_mode
    # In maskctc mode the last index is taken by the mask token, so
    # sos/eos move one position down.
    if self.decoder_mode == "maskctc":
        self.mask_token = odim - 1
        boundary_id = odim - 2
    else:
        boundary_id = odim - 1
    self.sos = boundary_id
    self.eos = boundary_id

    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    self.reset_parameters(args)

    self.adim = args.adim  # used for CTC (equal to d_model)
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0 else None
    )

    wants_error_report = args.report_cer or args.report_wer
    if wants_error_report:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and handed to
        the label-smoothing loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    # optional attention-span settings; defaults apply when the argument
    # namespace does not define them
    enc_attn_type = getattr(args, 'transformer_enc_attn_type', 'self_attn')
    dec_attn_type = getattr(args, 'transformer_dec_attn_type', 'self_attn')
    ratio_adaptive = getattr(args, 'ratio_adaptive', None)
    span_options = dict(
        span_init=getattr(args, 'span_init', None),
        span_ratio=getattr(args, 'span_ratio', None),
        ratio_adaptive=ratio_adaptive,
    )

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        attention_type=enc_attn_type,
        max_attn_span=getattr(args, 'enc_max_attn_span', [None]),
        **span_options
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        attention_type=dec_attn_type,
        max_attn_span=getattr(args, 'dec_max_attn_span', [None]),
        **span_options
    )

    # sos and eos share the last output index
    last_index = odim - 1
    self.sos = last_index
    self.eos = last_index
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    # self.lsm_weight = a
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    # self.verbose = args.verbose
    self.reset_parameters(args)

    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0 else None
    )

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None

    # keep the span-related settings on the model itself
    self.attention_enc_type = enc_attn_type
    self.attention_dec_type = dec_attn_type
    self.span_loss_coef = getattr(args, 'span_loss_coef', None)
    self.ratio_adaptive = ratio_adaptive
    self.sym_blank = args.sym_blank