def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the Transformer encoder/decoder, the loss and the optional extras.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored on the model and handed to the
        label-smoothing loss
    """
    torch.nn.Module.__init__(self)

    # An unset attention dropout inherits the general dropout rate; the
    # resolved value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=attn_dropout,
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    # <sos> and <eos> share the last output index.
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='asr', arch='transformer')
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset here, i.e. before the CTC module is created.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
    """Set up the chainer transformer: encoder, decoder, loss and optional CTC.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored as ``self.ignore_id``
    :param bool flag_return: stored as ``self.flag_return``
    :raises ValueError: if CTC is enabled and ``args.ctc_type`` is neither
        "builtin" nor "warpctc"
    """
    chainer.Chain.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"

    # An unset attention dropout inherits the general dropout rate.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    self.use_label_smoothing = False
    self.char_list = args.char_list
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.scale_emb = args.adim ** 0.5
    # <sos> and <eos> share the last output index.
    self.sos = odim - 1
    self.eos = odim - 1
    self.subsample = get_subsample(args, mode='asr', arch='transformer')
    self.ignore_id = ignore_id

    # Called before the links are built: the initialW / initialB attributes
    # read below are expected to exist after this call.
    self.reset_parameters(args)

    with self.init_scope():
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            initialW=self.initialW,
            initial_bias=self.initialB,
        )
        self.decoder = Decoder(
            odim, args, initialW=self.initialW, initial_bias=self.initialB
        )
        self.criterion = LabelSmoothingLoss(
            args.lsm_weight,
            len(args.char_list),
            args.transformer_length_normalized_loss,
        )
        if not args.mtlalpha > 0.0:
            self.ctc = None
        elif args.ctc_type == 'builtin':
            logging.info("Using chainer CTC implementation")
            self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
        elif args.ctc_type == 'warpctc':
            logging.info("Using warpctc CTC implementation")
            self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
        else:
            raise ValueError(
                'ctc_type must be "builtin" or "warpctc": {}'.format(
                    args.ctc_type
                )
            )

    self.dims = args.adim
    self.odim = odim
    self.flag_return = flag_return

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )

    # Verbosity defaults to 0 when it is absent (Namespace-like args) or
    # None (any other args object).
    if 'Namespace' in str(type(args)):
        self.verbose = args.verbose if 'verbose' in args else 0
    else:
        self.verbose = args.verbose if args.verbose is not None else 0
def __init__(self, idim, odim, args, flag_return=True):
    """Assemble the chainer attention/CTC model from ``args``.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param bool flag_return: stored as ``self.flag_return``
    """
    chainer.Chain.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
    self.etype = args.etype
    self.verbose = args.verbose
    self.char_list = args.char_list
    self.outdir = args.outdir

    # The last output index doubles as both the <sos> and the <eos> id.
    last_id = odim - 1
    self.sos = last_id
    self.eos = last_id

    # subsample info
    self.subsample = get_subsample(args, mode="asr", arch="rnn")

    # label smoothing info (only computed when a smoothing type is set)
    labeldist = None
    if args.lsm_type:
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )

    with self.init_scope():
        self.enc = encoder_for(args, idim, self.subsample)  # encoder
        self.ctc = ctc_for(args, odim)  # ctc
        self.att = att_for(args)  # attention
        self.dec = decoder_for(  # decoder
            args, odim, self.sos, self.eos, self.att, labeldist
        )

    self.acc = None
    self.loss = None
    self.flag_return = flag_return
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the encoder, the attention module and the linear output layer.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs; the model stores ``odim - 1``
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored on the model and handed to the
        label-smoothing loss
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    # An unset attention dropout inherits the general dropout rate.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    encoder_options = dict(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.encoder = Encoder(**encoder_options)

    # The stored output dimension is one less than the value passed in.
    self.odim = odim - 1
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Both layers use a fixed width of 256.
    # (original note next to these layers: "mean + std pooling")
    feature_width = 256
    self.output = torch.nn.Linear(feature_width, self.odim)
    self.att = Attention(feature_width)

    self.reset_parameters(args)
    self.adim = args.adim  # used for CTC (equal to d_model)
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with a shared encoder and a dual decoder.

    Builds the speech encoder, reads the cross-attention / wait-k /
    multilingual options from ``args`` (most via ``getattr`` with defaults),
    sanity-checks them with assertions, logs the resulting configuration and
    then creates the dual decoder, loss, optional MT encoder, optional CTC
    and optional error calculator.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored on the model and handed to the
        label-smoothing loss
    :raises AssertionError: if the option combination fails one of the
        checks below (assertions are skipped when Python runs with ``-O``)
    """
    torch.nn.Module.__init__(self)
    # An unset attention dropout inherits the general dropout rate; the
    # resolved value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate
    )

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    # ASR is active only with a positive ASR weight and mtlalpha below 1.
    self.do_asr = self.asr_weight > 0 and args.mtlalpha < 1

    # cross-attention parameters (each falls back to a default when the
    # attribute is missing from ``args``)
    self.cross_weight = getattr(args, "cross_weight", 0.0)
    self.cross_self = getattr(args, "cross_self", False)
    self.cross_src = getattr(args, "cross_src", False)
    self.cross_operator = getattr(args, "cross_operator", None)
    self.cross_to_asr = getattr(args, "cross_to_asr", False)
    self.cross_to_st = getattr(args, "cross_to_st", False)
    self.num_decoders = getattr(args, "num_decoders", 1)
    self.wait_k_asr = getattr(args, "wait_k_asr", 0)
    self.wait_k_st = getattr(args, "wait_k_st", 0)
    self.cross_src_from = getattr(args, "cross_src_from", "embedding")
    self.cross_self_from = getattr(args, "cross_self_from", "embedding")
    self.cross_weight_learnable = getattr(args, "cross_weight_learnable", False)

    # one-to-many ST experiments
    self.one_to_many = getattr(args, "one_to_many", False)
    self.langs_dict = getattr(args, "langs_dict", None)
    self.lang_tok = getattr(args, "lang_tok", None)

    self.normalize_before = getattr(args, "normalize_before", True)
    logging.info(f'self.normalize_before = {self.normalize_before}')

    # Check parameters
    # NOTE(review): the nesting of the checks below was reconstructed from
    # whitespace-collapsed text — confirm against the upstream file.
    # A 'sum' operator with a non-positive weight must not be combined with
    # any cross attention.
    if self.cross_operator == 'sum' and self.cross_weight <= 0:
        assert (not self.cross_to_asr) and (not self.cross_to_st)
    # Cross attention needs the ASR branch and at least one attachment point.
    if self.cross_to_asr or self.cross_to_st:
        assert self.do_asr
        assert self.cross_self or self.cross_src
    # An operator must be given exactly when cross attention is in use.
    assert bool(self.cross_operator) == (self.do_asr and (self.cross_to_asr or self.cross_to_st))
    if self.cross_src_from != "embedding" or self.cross_self_from != "embedding":
        assert self.normalize_before
    # At most one of the two wait-k values may be positive; negative values
    # are rejected by the final branch.
    if self.wait_k_asr > 0:
        assert self.wait_k_st == 0
    elif self.wait_k_st > 0:
        assert self.wait_k_asr == 0
    else:
        assert self.wait_k_asr == 0
        assert self.wait_k_st == 0

    # Log the resolved cross-attention configuration.
    logging.info("*** Cross attention parameters ***")
    if self.cross_to_asr:
        logging.info("| Cross to ASR")
    if self.cross_to_st:
        logging.info("| Cross to ST")
    if self.cross_self:
        logging.info("| Cross at Self")
    if self.cross_src:
        logging.info("| Cross at Source")
    if self.cross_to_asr or self.cross_to_st:
        logging.info(f'| Cross operator: {self.cross_operator}')
        logging.info(f'| Cross sum weight: {self.cross_weight}')
        if self.cross_src:
            logging.info(f'| Cross source from: {self.cross_src_from}')
        if self.cross_self:
            logging.info(f'| Cross self from: {self.cross_self_from}')
    logging.info(f'| wait_k_asr = {self.wait_k_asr}')
    logging.info(f'| wait_k_st = {self.wait_k_st}')

    # These warnings can only trigger when the normalize_before assertion
    # above was not enforced (e.g. assertions disabled).
    if (self.cross_src_from != "embedding" and self.cross_src) and (not self.normalize_before):
        logging.warning(f'WARNING: Resort to using self.cross_src_from == embedding for cross at source attention.')
    if (self.cross_self_from != "embedding" and self.cross_self) and (not self.normalize_before):
        logging.warning(f'WARNING: Resort to using self.cross_self_from == embedding for cross at self attention.')

    self.dual_decoder = DualDecoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        normalize_before=self.normalize_before,
        cross_operator=self.cross_operator,
        cross_weight_learnable=self.cross_weight_learnable,
        cross_weight=self.cross_weight,
        cross_self=self.cross_self,
        cross_src=self.cross_src,
        cross_to_asr=self.cross_to_asr,
        cross_to_st=self.cross_to_st
    )

    self.pad = 0
    # <sos> and <eos> share the last output index.
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.idim = idim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id, args.lsm_weight,
                                        args.transformer_length_normalized_loss)
    self.adim = args.adim

    # submodule for MT task: a text encoder over the output vocabulary,
    # created only when the MT weight is positive
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0
        )

    self.reset_parameters(args)  # place after the submodule initialization

    # CTC is created after the parameter reset, only when mtlalpha is positive.
    if args.mtlalpha > 0.0:
        self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
    else:
        self.ctc = None

    # CER/WER reporting additionally requires a positive ASR weight.
    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(args.char_list,
                                                args.sym_space, args.sym_blank,
                                                args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    if self.multilingual:
        assert self.replace_sos

    # NOTE(review): whether this block sits inside ``if self.multilingual``
    # cannot be told from the collapsed text — confirm.
    if self.lang_tok == "encoder-pre-sum":
        self.language_embeddings = build_embedding(self.langs_dict, self.idim, padding_idx=self.pad)
        print(f'language_embeddings: {self.language_embeddings}')
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
    """Construct an E2E object for transducer model.

    Depending on ``args.etype`` / ``args.dtype`` the encoder and decoder are
    either transformer-based (``self.encoder`` / ``self.decoder``) or
    RNN-based (``self.enc`` / ``self.dec``).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options; in the
        transformer-encoder case ``args.eprojs`` is overwritten with the
        encoder output size
    :param int ignore_id: id stored as ``self.ignore_id``
    :param int blank_id: blank id handed to the RNN decoders and the loss
    :raises ValueError: if a transformer encoder or decoder is requested
        without the corresponding block architecture
    """
    torch.nn.Module.__init__(self)

    if "transformer" in args.etype:
        if args.enc_block_arch is None:
            # The literals are joined by implicit concatenation, so each
            # piece must carry its own trailing space.
            raise ValueError(
                "Transformer-based blocks in transducer mode should be "
                "defined individually in the YAML file. "
                "See egs/vivos/asr1/conf/transducer/* for more info."
            )

        self.subsample = get_subsample(args, mode="asr", arch="transformer")

        # 2. use transformer to joint feature maps
        # transformer without positional encoding
        self.clayers = repeat(
            2,
            lambda lnum: EncoderLayer(
                16,
                MultiHeadedAttention(4, 16, 0.1),
                PositionwiseFeedForward(16, 2048, 0.1),
                dropout_rate=0.1,
                normalize_before=True,
                concat_after=False,
            ),
        )
        self.conv = torch.nn.Sequential(
            torch.nn.Conv2d(1, 32, kernel_size=(3, 5), stride=(1, 2)),
            torch.nn.ReLU(),
            torch.nn.Conv2d(32, 32, kernel_size=(3, 7), stride=(2, 2)),
            torch.nn.ReLU(),
        )

        self.encoder = Encoder(
            idim,
            args.enc_block_arch,
            input_layer=args.transformer_enc_input_layer,
            repeat_block=args.enc_block_repeat,
            self_attn_type=args.transformer_enc_self_attn_type,
            positional_encoding_type=(
                args.transformer_enc_positional_encoding_type
            ),
            positionwise_activation_type=(
                args.transformer_enc_pw_activation_type
            ),
            conv_mod_activation_type=(
                args.transformer_enc_conv_mod_activation_type
            ),
        )
        encoder_out = self.encoder.enc_out
        # The RNN decoders below read the encoder size from args.eprojs.
        args.eprojs = self.encoder.enc_out

        self.most_dom_list = args.enc_block_arch[:]
    else:
        self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

        self.enc = encoder_for(args, idim, self.subsample)
        encoder_out = args.eprojs

    if "transformer" in args.dtype:
        if args.dec_block_arch is None:
            raise ValueError(
                "Transformer-based blocks in transducer mode should be "
                "defined individually in the YAML file. "
                "See egs/vivos/asr1/conf/transducer/* for more info."
            )

        self.decoder = DecoderTT(
            odim,
            encoder_out,
            args.joint_dim,
            args.dec_block_arch,
            input_layer=args.transformer_dec_input_layer,
            repeat_block=args.dec_block_repeat,
            joint_activation_type=args.joint_activation_type,
            positionwise_activation_type=(
                args.transformer_dec_pw_activation_type
            ),
            dropout_rate_embed=args.dropout_rate_embed_decoder,
        )

        # Collect block descriptions of every transformer part in use.
        if "transformer" in args.etype:
            self.most_dom_list += args.dec_block_arch[:]
        else:
            self.most_dom_list = args.dec_block_arch[:]
    else:
        if args.rnnt_mode == "rnnt-att":
            self.att = att_for(args)

            self.dec = DecoderRNNTAtt(
                args.eprojs,
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                self.att,
                args.dec_embed_dim,
                args.joint_dim,
                args.joint_activation_type,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
        else:
            self.dec = DecoderRNNT(
                args.eprojs,
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.joint_dim,
                args.joint_activation_type,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )

    if hasattr(self, "most_dom_list"):
        # The (value, count) pairs are sorted by value in descending order,
        # so this selects the largest "d_hidden" among the collected blocks.
        # Raises IndexError when no block declares "d_hidden".
        # NOTE(review): the name suggests "most common" — confirm intent.
        self.most_dom_dim = sorted(
            Counter(
                d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d
            ).most_common(),
            key=lambda x: x[0],
            reverse=True,
        )[0][0]

    self.etype = args.etype
    self.dtype = args.dtype
    self.rnnt_mode = args.rnnt_mode

    # <sos> and <eos> share the last output index.
    self.sos = odim - 1
    self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim

    self.reporter = Reporter()

    self.criterion = TransLoss(args.trans_type, self.blank_id)

    self.default_parameters(args)

    if args.report_cer or args.report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

        # Use the same membership test that decided which decoder attribute
        # was created above, so the lookup can never hit a missing one.
        decoder = self.decoder if "transformer" in self.dtype else self.dec

        self.error_calculator = ErrorCalculatorTransducer(
            decoder,
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None

    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the Transformer encoder plus the target-matching layers.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored on the model and handed to the
        label-smoothing loss
    """
    torch.nn.Module.__init__(self)

    # An unset attention dropout inherits the general dropout rate.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate

    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    # target matching system organization
    self.oversampling = args.oversampling
    self.residual = args.residual
    self.outer = args.outer
    self.poster = torch.nn.Linear(args.adim, odim * self.oversampling)
    # NOTE(review): the pairing of this if/else was reconstructed from
    # whitespace-collapsed text — confirm against the upstream file.
    if self.outer and self.residual:
        self.matcher_res = torch.nn.Linear(idim, odim)
        self.matcher = torch.nn.Linear(odim, odim)
    elif self.outer:
        self.matcher = torch.nn.Linear(odim + idim, odim)

    # <sos> and <eos> share the last output index.
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameters are reset here, i.e. before the CTC module is created.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    wants_report = args.report_cer or args.report_wer
    if wants_report:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, training=True):
    """Construct an E2E object for transducer model.

    Depending on ``args.etype`` / ``args.dtype`` the encoder and decoder are
    either transformer-based (``self.encoder`` / ``self.decoder``) or
    RNN-based (``self.enc`` / ``self.dec``).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options; in the
        transformer-encoder case ``args.eprojs`` is overwritten with the
        encoder output size
    :param int ignore_id: id stored as ``self.ignore_id``
    :param int blank_id: blank id handed to the RNN decoders and the loss
    :param bool training: when true, the transducer loss is created as
        ``self.criterion``
    :raises ValueError: if a transformer encoder or decoder is requested
        without the corresponding block architecture
    """
    torch.nn.Module.__init__(self)

    if "transformer" in args.etype:
        if args.enc_block_arch is None:
            # The literals are joined by implicit concatenation, so each
            # piece must carry its own trailing space.
            raise ValueError(
                "Transformer-based blocks in transducer mode should be "
                "defined individually in the YAML file. "
                "See egs/vivos/asr1/conf/transducer/* for more info."
            )

        self.subsample = get_subsample(args, mode="asr", arch="transformer")

        self.encoder = Encoder(
            idim,
            args.enc_block_arch,
            input_layer=args.transformer_enc_input_layer,
            repeat_block=args.enc_block_repeat,
            self_attn_type=args.transformer_enc_self_attn_type,
            positional_encoding_type=(
                args.transformer_enc_positional_encoding_type
            ),
            positionwise_activation_type=(
                args.transformer_enc_pw_activation_type
            ),
            conv_mod_activation_type=(
                args.transformer_enc_conv_mod_activation_type
            ),
        )
        encoder_out = self.encoder.enc_out
        # The RNN decoders below read the encoder size from args.eprojs.
        args.eprojs = self.encoder.enc_out

        self.most_dom_list = args.enc_block_arch[:]
    else:
        self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

        self.enc = encoder_for(args, idim, self.subsample)
        encoder_out = args.eprojs

    if "transformer" in args.dtype:
        if args.dec_block_arch is None:
            raise ValueError(
                "Transformer-based blocks in transducer mode should be "
                "defined individually in the YAML file. "
                "See egs/vivos/asr1/conf/transducer/* for more info."
            )

        self.decoder = DecoderTT(
            odim,
            encoder_out,
            args.joint_dim,
            args.dec_block_arch,
            input_layer=args.transformer_dec_input_layer,
            repeat_block=args.dec_block_repeat,
            joint_activation_type=args.joint_activation_type,
            positionwise_activation_type=(
                args.transformer_dec_pw_activation_type
            ),
            dropout_rate_embed=args.dropout_rate_embed_decoder,
        )

        # Collect block descriptions of every transformer part in use.
        if "transformer" in args.etype:
            self.most_dom_list += args.dec_block_arch[:]
        else:
            self.most_dom_list = args.dec_block_arch[:]
    else:
        if args.rnnt_mode == "rnnt-att":
            self.att = att_for(args)

            self.dec = DecoderRNNTAtt(
                args.eprojs,
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                self.att,
                args.dec_embed_dim,
                args.joint_dim,
                args.joint_activation_type,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )
        else:
            self.dec = DecoderRNNT(
                args.eprojs,
                odim,
                args.dtype,
                args.dlayers,
                args.dunits,
                blank_id,
                args.dec_embed_dim,
                args.joint_dim,
                args.joint_activation_type,
                args.dropout_rate_decoder,
                args.dropout_rate_embed_decoder,
            )

    if hasattr(self, "most_dom_list"):
        # The (value, count) pairs are sorted by value in descending order,
        # so this selects the largest "d_hidden" among the collected blocks.
        # Raises IndexError when no block declares "d_hidden".
        # NOTE(review): the name suggests "most common" — confirm intent.
        self.most_dom_dim = sorted(
            Counter(
                d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d
            ).most_common(),
            key=lambda x: x[0],
            reverse=True,
        )[0][0]

    self.etype = args.etype
    self.dtype = args.dtype
    self.rnnt_mode = args.rnnt_mode

    # <sos> and <eos> share the last output index.
    self.sos = odim - 1
    self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim

    self.reporter = Reporter()

    # The loss is only needed (and only created) for training.
    if training:
        self.criterion = TransLoss(args.trans_type, self.blank_id)

    self.default_parameters(args)

    if args.report_cer or args.report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculatorTransducer

        # Use the same membership test that decided which decoder attribute
        # was created above, so the lookup can never hit a missing one.
        decoder = self.decoder if "transformer" in self.dtype else self.dec

        self.error_calculator = ErrorCalculatorTransducer(
            decoder,
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None

    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, training=True):
    """Construct an E2E object for transducer model.

    Depending on ``args.etype`` / ``args.dtype`` the encoder and decoder are
    either custom block stacks (``self.encoder`` / ``self.decoder``) or
    RNN-based (``self.enc`` / ``self.dec``); a joint network combines them.
    Loss, error calculator and auxiliary modules are created only when
    ``training`` is true.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored as ``self.ignore_id`` and handed to the
        auxiliary label-smoothing loss
    :param int blank_id: blank id handed to the RNN decoder and the loss
    :param bool training: whether training-only components are built
    :raises ValueError: if a custom encoder or decoder is requested without
        the corresponding block architecture
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    self.is_rnnt = True
    self.transducer_weight = args.transducer_weight

    # All auxiliary objectives are disabled outside of training.
    self.use_aux_task = bool(args.aux_task_type is not None and training)

    self.use_aux_ctc = args.aux_ctc and training
    self.aux_ctc_weight = args.aux_ctc_weight

    self.use_aux_cross_entropy = args.aux_cross_entropy and training
    self.aux_cross_entropy_weight = args.aux_cross_entropy_weight

    if self.use_aux_task:
        # Count of layers passed to the validator: total layers minus one.
        n_layers = (
            (len(args.enc_block_arch) * args.enc_block_repeat - 1)
            if args.enc_block_arch is not None
            else (args.elayers - 1)
        )

        aux_task_layer_list = valid_aux_task_layer_list(
            args.aux_task_layer_list,
            n_layers,
        )
    else:
        aux_task_layer_list = []

    if "custom" in args.etype:
        if args.enc_block_arch is None:
            # The literals are joined by implicit concatenation, so each
            # piece must carry its own trailing space.
            raise ValueError(
                "When specifying custom encoder type, --enc-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.subsample = get_subsample(args, mode="asr", arch="transformer")

        self.encoder = CustomEncoder(
            idim,
            args.enc_block_arch,
            input_layer=args.custom_enc_input_layer,
            repeat_block=args.enc_block_repeat,
            self_attn_type=args.custom_enc_self_attn_type,
            positional_encoding_type=args.custom_enc_positional_encoding_type,
            positionwise_activation_type=args.custom_enc_pw_activation_type,
            conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
            aux_task_layer_list=aux_task_layer_list,
        )
        encoder_out = self.encoder.enc_out

        self.most_dom_list = args.enc_block_arch[:]
    else:
        self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

        self.enc = encoder_for(
            args,
            idim,
            self.subsample,
            aux_task_layer_list=aux_task_layer_list,
        )
        encoder_out = args.eprojs

    if "custom" in args.dtype:
        if args.dec_block_arch is None:
            raise ValueError(
                "When specifying custom decoder type, --dec-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.decoder = CustomDecoder(
            odim,
            args.dec_block_arch,
            input_layer=args.custom_dec_input_layer,
            repeat_block=args.dec_block_repeat,
            positionwise_activation_type=args.custom_dec_pw_activation_type,
            dropout_rate_embed=args.dropout_rate_embed_decoder,
        )
        decoder_out = self.decoder.dunits

        # Collect block descriptions of every custom part in use.
        if "custom" in args.etype:
            self.most_dom_list += args.dec_block_arch[:]
        else:
            self.most_dom_list = args.dec_block_arch[:]
    else:
        self.dec = DecoderRNNT(
            odim,
            args.dtype,
            args.dlayers,
            args.dunits,
            blank_id,
            args.dec_embed_dim,
            args.dropout_rate_decoder,
            args.dropout_rate_embed_decoder,
        )
        decoder_out = args.dunits

    self.joint_network = JointNetwork(
        odim, encoder_out, decoder_out, args.joint_dim, args.joint_activation_type
    )

    if hasattr(self, "most_dom_list"):
        # The (value, count) pairs are sorted by value in descending order,
        # so this selects the largest "d_hidden" among the collected blocks.
        # Raises IndexError when no block declares "d_hidden".
        # NOTE(review): the name suggests "most common" — confirm intent.
        self.most_dom_dim = sorted(
            Counter(
                d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d
            ).most_common(),
            key=lambda x: x[0],
            reverse=True,
        )[0][0]

    self.etype = args.etype
    self.dtype = args.dtype

    # <sos> and <eos> share the last output index.
    self.sos = odim - 1
    self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim

    self.reporter = Reporter()

    self.error_calculator = None

    self.default_parameters(args)

    # NOTE(review): the extent of this ``if training`` body was rebuilt from
    # whitespace-collapsed text; the aux flags are all false when not
    # training, so only the error calculator is affected — confirm.
    if training:
        self.criterion = TransLoss(args.trans_type, self.blank_id)

        # Use the same membership test that decided which decoder attribute
        # was created above, so the lookup can never hit a missing one.
        decoder = self.decoder if "custom" in self.dtype else self.dec

        if args.report_cer or args.report_wer:
            self.error_calculator = ErrorCalculator(
                decoder,
                self.joint_network,
                args.char_list,
                args.sym_space,
                args.sym_blank,
                args.report_cer,
                args.report_wer,
            )

        if self.use_aux_task:
            self.auxiliary_task = AuxiliaryTask(
                decoder,
                self.joint_network,
                self.criterion,
                args.aux_task_type,
                args.aux_task_weight,
                encoder_out,
                args.joint_dim,
            )

        if self.use_aux_ctc:
            self.aux_ctc = ctc_for(
                Namespace(
                    num_encs=1,
                    eprojs=encoder_out,
                    dropout_rate=args.aux_ctc_dropout_rate,
                    ctc_type="warpctc",
                ),
                odim,
            )

        if self.use_aux_cross_entropy:
            self.aux_decoder_output = torch.nn.Linear(decoder_out, odim)

            self.aux_cross_entropy = LabelSmoothingLoss(
                odim, ignore_id, args.aux_cross_entropy_smoothing
            )

    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args):
    """Set up the multi-speaker E2E model (encoder, CTC, attention, decoder).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    torch.nn.Module.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.reporter = Reporter()
    self.num_spkrs = args.num_spkrs
    self.spa = args.spa
    self.pit = PIT(self.num_spkrs)

    # The last output index doubles as both the <sos> and the <eos> id.
    self.sos = self.eos = odim - 1

    # subsample info
    self.subsample = get_subsample(args, mode='asr', arch='rnn_mix')

    # label smoothing info: needs a smoothing type and an existing
    # training json file
    labeldist = None
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )

    # getattr keeps compatibility with args that lack ``use_frontend``
    if not getattr(args, "use_frontend", False):
        self.frontend = None
    else:
        # Imported here, only when the frontend is actually used.
        from espnet.nets.pytorch_backend.frontends.feature_transform \
            import feature_transform_for
        from espnet.nets.pytorch_backend.frontends.frontend \
            import frontend_for

        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(
            args, (idim - 1) * 2
        )
        # From here on the encoder input size is the mel dimension.
        idim = args.n_mels

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # ctc
    self.ctc = ctc_for(args, odim, reduce=False)
    # attention: one per speaker when ``args.spa`` is set, otherwise one
    if args.spa:
        num_att = self.num_spkrs
    else:
        num_att = 1
    self.att = att_for(args, num_att)
    # decoder
    self.dec = decoder_for(
        args, odim, self.sos, self.eos, self.att, labeldist
    )

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    if 'report_cer' in vars(args) and (args.report_cer or args.report_wer):
        self.recog_args = argparse.Namespace(
            beam_size=args.beam_size,
            penalty=args.penalty,
            ctc_weight=args.ctc_weight,
            maxlenratio=args.maxlenratio,
            minlenratio=args.minlenratio,
            lm_weight=args.lm_weight,
            rnnlm=args.rnnlm,
            nbest=args.nbest,
            space=args.sym_space,
            blank=args.sym_blank,
        )
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    self.rnnlm = None
    self.logzero = -1e10
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, mono_odim, args, ignore_id=-1):
    """Set up the encoder/CTC model with two output projections.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param int mono_odim: output dimension of the second ("mono")
        projection ``self.poster_mono``
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: id stored as ``self.ignore_id``
    :raises NotImplementedError: if ``args.initializer`` is not one of
        "lecun", "orthogonal" or "xavier"
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    self.outdir = args.outdir

    # target matching system organization
    self.oversampling = args.oversampling
    self.residual = args.residual
    self.outer = args.outer
    self.poster = torch.nn.Linear(args.eprojs, odim * self.oversampling)
    self.poster_mono = torch.nn.Linear(
        args.eprojs, mono_odim * self.oversampling
    )

    # The last index of each output vocabulary doubles as <sos> and <eos>.
    self.sos = self.eos = odim - 1
    self.sos_mono = self.eos_mono = mono_odim - 1
    self.odim = odim
    self.mono_odim = mono_odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="rnn")
    self.reporter = Reporter()

    # label smoothing info: needs a smoothing type and an existing
    # training json file (the result is not used further in this method)
    labeldist = None
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )

    # getattr keeps compatibility with args that lack ``use_frontend``
    if not getattr(args, "use_frontend", False):
        self.frontend = None
    else:
        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(
            args, (idim - 1) * 2
        )
        # From here on the encoder input size is the mel dimension.
        idim = args.n_mels

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # ctc
    self.ctc = ctc_for(args, odim)

    # weight initialization, selected by name
    init_method_names = {
        "lecun": "init_like_chainer",
        "orthogonal": "init_orthogonal",
        "xavier": "init_xavier",
    }
    if args.initializer not in init_method_names:
        raise NotImplementedError(
            "unknown initializer: " + args.initializer
        )
    getattr(self, init_method_names[args.initializer])()

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )

    self.rnnlm = None
    self.logzero = -1e10
    self.loss = None
    self.acc = None
arch_subsample = "transformer" else: raise ValueError("Unsupported model module: %s" % model_module) model_class = dynamic_import(model_module) model_class.add_arguments(parser) args = parser.parse_args(cmd_args) # subsampling info if hasattr(args, "etype") and args.etype.startswith("vgg"): # Subsampling is not performed for vgg*. # It is performed in max pooling layers at CNN. min_io_ratio = 4 else: subsample = get_subsample(args, mode=args.mode_subsample, arch=arch_subsample) # the minimum input-output length ratio for all samples min_io_ratio = reduce(mul, subsample) # load dictionary with open(args.data_json, "rb") as f: j = json.load(f)["utts"] # remove samples with IO ratio smaller than `min_io_ratio` for key in list(j.keys()): ilen = j[key]["input"][0]["shape"][0] olen = min(x["shape"][0] for x in j[key]["output"]) if float(ilen) - float(olen) * min_io_ratio < args.min_io_delta: j.pop(key) print("'{}' removed".format(key))
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the transformer encoder/decoder pair for translation.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and passed to
        the label smoothing loss
    :raises ValueError: if ``args.tie_src_tgt_embedding`` is set while
        ``idim`` differs from ``odim``
    """
    torch.nn.Module.__init__(self)

    # the attention dropout falls back to the generic dropout rate; the
    # fallback is written back into ``args``
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # options shared by the encoder and the decoder
    common = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
    )
    self.encoder = Encoder(
        idim=idim,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer='embed',
        attention_dropout_rate=attn_dropout,
        **common
    )
    self.decoder = Decoder(
        odim=odim,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        **common
    )

    self.pad = 0
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='mt', arch='transformer')
    self.reporter = Reporter()

    # tie source and target embeddings
    if args.tie_src_tgt_embedding:
        if idim != odim:
            raise ValueError(
                'When using tie_src_tgt_embedding, idim and odim must be equal.'
            )
        self.encoder.embed[0].weight = self.decoder.embed[0].weight

    # tie embeddings and the classifier
    if args.tie_classifier:
        self.decoder.output_layer.weight = self.decoder.embed[0].weight

    length_normalized = args.transformer_length_normalized_loss
    self.criterion = LabelSmoothingLoss(
        self.odim, self.ignore_id, args.lsm_weight, length_normalized
    )
    self.normalize_length = length_normalized  # for PPL
    self.reset_parameters(args)
    self.adim = args.adim

    if args.report_bleu:
        # imported lazily: only needed when BLEU reporting is requested
        from espnet.nets.e2e_mt_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list, args.sym_space, args.report_bleu
        )
    else:
        self.error_calculator = None
    self.rnnlm = None

    # multilingual NMT related
    self.multilingual = args.multilingual
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object with a multimodal encoder and decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and passed to
        the label smoothing loss
    :raises ValueError: if ``args.encoder_type`` is not one of
        'all_add', 'proportion_add' or 'vat'
    """
    torch.nn.Module.__init__(self)
    # the attention dropout falls back to the generic dropout rate
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    self.encoder_type = getattr(args, 'encoder_type', 'all_add')
    self.vbs = getattr(args, 'vbs', False)
    self.noise = getattr(args, 'noise_type', 'none')

    # The encoder implementation is selected by name and imported lazily.
    if self.encoder_type == 'all_add':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_all_add import MultimodalEncoder
    elif self.encoder_type == 'proportion_add':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_proportion_add import MultimodalEncoder
    elif self.encoder_type == 'vat':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_vat import MultimodalEncoder
    else:
        # Without this branch MultimodalEncoder stays unbound and the
        # constructor fails below with an opaque UnboundLocalError.
        raise ValueError(
            "unknown encoder_type: {!r} (expected 'all_add', "
            "'proportion_add' or 'vat')".format(self.encoder_type)
        )

    self.encoder = MultimodalEncoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        visual_dim=args.visual_dim,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        vbs=self.vbs,
    )
    self.decoder = MultimodalDecoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.pad = 0
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    if self.asr_weight > 0 and args.mtlalpha < 1:
        self.decoder_asr = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )

    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0,
        )

    self.reset_parameters(args)  # place after the submodule initialization

    if args.mtlalpha > 0.0:
        self.ctc = CTC(odim, args.adim, args.dropout_rate,
                       ctc_type=args.ctc_type, reduce=True)
    else:
        self.ctc = None

    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        # imported lazily: only needed when CER/WER reporting is requested
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list, args.sym_space, args.sym_blank,
            args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    if self.multilingual:
        assert self.replace_sos

    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
def __init__(self, idim, odim, args):
    """Set up the encoder, decoder and reporting options of the transducer.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options

    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    self.rnnt_mode = args.rnnt_mode
    self.etype = args.etype
    self.verbose = args.verbose
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()
    self.beam_size = args.beam_size

    # note that eos is the same as sos (equivalent ID)
    self.sos = self.eos = odim - 1

    # subsample info
    self.subsample = get_subsample(args, mode='asr', arch='rnn-t')

    if not args.use_frontend:
        self.frontend = None
    else:
        # imported lazily: only needed when a frontend is requested
        from espnet.nets.pytorch_backend.frontends.feature_transform \
            import feature_transform_for
        from espnet.nets.pytorch_backend.frontends.frontend \
            import frontend_for

        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
        # the encoder input dimension becomes args.n_mels
        idim = args.n_mels

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)

    # decoder: receives an attention module only in 'rnnt-att' mode
    decoder_inputs = [args, odim]
    if args.rnnt_mode == 'rnnt-att':
        self.att = att_for(args)
        decoder_inputs.append(self.att)
    self.dec = decoder_for(*decoder_inputs)

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    reporting = 'report_cer' in vars(args) and (
        args.report_cer or args.report_wer
    )
    if reporting:
        self.recog_args = argparse.Namespace(
            beam_size=args.beam_size,
            nbest=args.nbest,
            space=args.sym_space,
            score_norm_transducer=args.score_norm_transducer,
        )
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    self.logzero = -10000000000.0
    self.rnnlm = None
    self.loss = None
def __init__(self, idim, odim, args):
    """Construct an E2E object with an optional meeting knowledge base.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()

    # below means the last number becomes eos/sos ID
    # note that sos/eos IDs are identical
    self.sos = odim - 1
    self.eos = odim - 1

    # gs534 - word vocab
    # hack here for bpe flag: more than 100 output units is treated as BPE.
    # char_list defaults to None above, so guard before taking its length.
    bpe = self.char_list is not None and len(self.char_list) > 100
    self.vocabulary = (
        Vocabulary(args.dictfile, bpe) if args.dictfile != '' else None
    )

    # gs534 - create the meeting knowledge base(s)
    self.meeting_KB = None
    self.n_KBs = getattr(args, 'dynamicKBs', 0)
    if args.meetingKB and args.meetingpath != '':
        has_splits = os.path.isdir(os.path.join(args.meetingpath, 'split_0'))
        if self.n_KBs == 0 or not has_splits:
            self.meeting_KB = KBmeeting(
                self.vocabulary, args.meetingpath, args.char_list, bpe
            )
        else:
            # arrange multiple KBs, one per 'split_<i>' sub-directory
            self.meeting_KB = [
                KBmeeting(
                    self.vocabulary,
                    os.path.join(args.meetingpath, 'split_{}'.format(i)),
                    args.char_list,
                    bpe,
                )
                for i in range(self.n_KBs)
            ]

    # subsample info
    self.subsample = get_subsample(args, mode="asr", arch="rnn")

    # label smoothing info
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )
    else:
        labeldist = None

    if getattr(args, "use_frontend", False):  # use getattr to keep compatibility
        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
        idim = args.n_mels
    else:
        self.frontend = None

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # ctc
    self.ctc = ctc_for(args, odim)
    # attention
    self.att = att_for(args)
    # decoder: when several KBs were built, it is given the first one
    if isinstance(self.meeting_KB, list):
        decoder_kb = self.meeting_KB[0]
    else:
        decoder_kb = self.meeting_KB
    self.dec = decoder_for(
        args, odim, self.sos, self.eos, self.att, labeldist,
        meetingKB=decoder_kb,
    )

    # weight initialization
    self.init_from = getattr(args, 'init_full_model', None)
    self.init_like_chainer()

    # options for beam search
    if args.report_cer or args.report_wer:
        recog_args = {
            "beam_size": args.beam_size,
            "penalty": args.penalty,
            "ctc_weight": args.ctc_weight,
            "maxlenratio": args.maxlenratio,
            "minlenratio": args.minlenratio,
            "lm_weight": args.lm_weight,
            "rnnlm": args.rnnlm,
            "nbest": args.nbest,
            "space": args.sym_space,
            "blank": args.sym_blank,
        }
        self.recog_args = argparse.Namespace(**recog_args)
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the translation transformer with selectable self-attention layers.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id`` and passed to
        the label smoothing loss
    :raises ValueError: if ``args.tie_src_tgt_embedding`` is set while
        ``idim`` differs from ``odim``
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    # the attention dropout falls back to the generic dropout rate
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # options shared by the encoder and the decoder
    common = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_usebias=args.ldconv_usebias,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
    )
    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer="embed",
        attention_dropout_rate=attn_dropout,
        **common
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        **common
    )

    self.pad = 0  # use <blank> for padding
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="mt", arch="transformer")
    self.reporter = Reporter()

    # tie source and target embeddings
    if args.tie_src_tgt_embedding:
        if idim != odim:
            raise ValueError(
                "When using tie_src_tgt_embedding, idim and odim must be equal."
            )
        self.encoder.embed[0].weight = self.decoder.embed[0].weight

    # tie embeddings and the classifier
    if args.tie_classifier:
        self.decoder.output_layer.weight = self.decoder.embed[0].weight

    length_normalized = args.transformer_length_normalized_loss
    self.criterion = LabelSmoothingLoss(
        self.odim, self.ignore_id, args.lsm_weight, length_normalized
    )
    self.normalize_length = length_normalized  # for PPL
    self.reset_parameters(args)
    self.adim = args.adim

    # built unconditionally; args.report_bleu is handed to the calculator
    self.error_calculator = ErrorCalculator(
        args.char_list, args.sym_space, args.sym_blank, args.report_bleu
    )
    self.rnnlm = None

    # multilingual MT related
    self.multilingual = args.multilingual
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the E2E model with a second, detached ("ema") encoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: value stored as ``self.ignore_id``
    :raises NotImplementedError: if ``args.initializer`` is neither
        "lecun" nor "orthogonal"
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    self.outdir = args.outdir

    # the last index of the vocabulary doubles as both sos and eos
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="rnn")
    self.reporter = Reporter()

    # ICT related
    self.scheme = args.mixup_scheme
    self.consistency_weight = args.consistency_weight
    self.consistency_rampup_starts = args.consistency_rampup_starts
    self.consistency_rampup_ends = args.consistency_rampup_ends
    self.mixup_alpha = args.mixup_alpha
    # if True, print out student model accuracy
    self.show_student_model_acc = args.show_student_model_acc

    # label smoothing info (the distribution is computed but not consumed
    # anywhere else in this constructor)
    use_label_smoothing = args.lsm_type and os.path.isfile(args.train_json)
    if use_label_smoothing:
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )
    else:
        labeldist = None

    # use getattr to keep compatibility with args lacking use_frontend
    if not getattr(args, "use_frontend", False):
        self.frontend = None
    else:
        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
        idim = args.n_mels

    # encoder, plus a second encoder built with identical arguments whose
    # parameters are detached in place
    self.enc = encoder_for(args, idim, odim, self.subsample)
    self.ema_enc = encoder_for(args, idim, odim, self.subsample)
    for ema_param in self.ema_enc.parameters():
        ema_param.detach_()

    # leave ctc for future works

    # weight initialization: map the option value to the method to call
    init_method_names = {
        "lecun": "init_like_chainer",
        "orthogonal": "init_orthogonal",
    }
    if args.initializer not in init_method_names:
        raise NotImplementedError("unknown initializer: " + args.initializer)
    getattr(self, init_method_names[args.initializer])()

    wants_error_report = args.report_cer or args.report_wer
    self.error_calculator = (
        ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        if wants_error_report
        else None
    )

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(
    self,
    idim: int,
    odim: int,
    args: Namespace,
    ignore_id: int = -1,
    blank_id: int = 0,
    training: bool = True,
):
    """Construct an E2E object for transducer model.

    Args:
        idim: Dimension of inputs.
        odim: Dimension of outputs.
        args: Argument Namespace containing options.
        ignore_id: Stored as ``self.ignore_id`` and passed to the
            transducer tasks.
        blank_id: Stored as ``self.blank_id`` and passed to the decoder
            and the transducer tasks.
        training: Auxiliary encoder outputs and the error calculator are
            only enabled when this is True.

    Raises:
        ValueError: If a custom encoder/decoder type is requested without
            the matching block architecture.

    """
    torch.nn.Module.__init__(self)

    args = fill_missing_args(args, self.add_arguments)

    self.is_transducer = True

    self.use_auxiliary_enc_outputs = bool(
        training and args.use_aux_transducer_loss
    )

    self.subsample = get_subsample(
        args, mode="asr", arch="transformer" if args.etype == "custom" else "rnn-t"
    )

    if self.use_auxiliary_enc_outputs:
        # index of the last encoder layer: total number of layers minus one
        if args.enc_block_arch is not None:
            n_layers = (len(args.enc_block_arch) * args.enc_block_repeat) - 1
        else:
            n_layers = args.elayers - 1

        aux_enc_output_layers = valid_aux_encoder_output_layers(
            args.aux_transducer_loss_enc_output_layers,
            n_layers,
            args.use_symm_kl_div_loss,
            self.subsample,
        )
    else:
        aux_enc_output_layers = []

    if args.etype == "custom":
        if args.enc_block_arch is None:
            # Adjacent literals are concatenated verbatim, so each piece
            # carries its own trailing space.
            raise ValueError(
                "When specifying custom encoder type, --enc-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.encoder = CustomEncoder(
            idim,
            args.enc_block_arch,
            input_layer=args.custom_enc_input_layer,
            repeat_block=args.enc_block_repeat,
            self_attn_type=args.custom_enc_self_attn_type,
            positional_encoding_type=args.custom_enc_positional_encoding_type,
            positionwise_activation_type=args.custom_enc_pw_activation_type,
            conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
            aux_enc_output_layers=aux_enc_output_layers,
        )
        encoder_out = self.encoder.enc_out
    else:
        self.enc = encoder_for(
            args,
            idim,
            self.subsample,
            aux_enc_output_layers=aux_enc_output_layers,
        )
        encoder_out = args.eprojs

    if args.dtype == "custom":
        if args.dec_block_arch is None:
            raise ValueError(
                "When specifying custom decoder type, --dec-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.decoder = CustomDecoder(
            odim,
            args.dec_block_arch,
            input_layer=args.custom_dec_input_layer,
            repeat_block=args.dec_block_repeat,
            positionwise_activation_type=args.custom_dec_pw_activation_type,
            dropout_rate_embed=args.dropout_rate_embed_decoder,
            blank_id=blank_id,
        )
        decoder_out = self.decoder.dunits
    else:
        self.dec = RNNDecoder(
            odim,
            args.dtype,
            args.dlayers,
            args.dunits,
            args.dec_embed_dim,
            dropout_rate=args.dropout_rate_decoder,
            dropout_rate_embed=args.dropout_rate_embed_decoder,
            blank_id=blank_id,
        )
        decoder_out = args.dunits

    self.transducer_tasks = TransducerTasks(
        encoder_out,
        decoder_out,
        args.joint_dim,
        odim,
        joint_activation_type=args.joint_activation_type,
        transducer_loss_weight=args.transducer_weight,
        ctc_loss=args.use_ctc_loss,
        ctc_loss_weight=args.ctc_loss_weight,
        ctc_loss_dropout_rate=args.ctc_loss_dropout_rate,
        lm_loss=args.use_lm_loss,
        lm_loss_weight=args.lm_loss_weight,
        lm_loss_smoothing_rate=args.lm_loss_smoothing_rate,
        aux_transducer_loss=args.use_aux_transducer_loss,
        aux_transducer_loss_weight=args.aux_transducer_loss_weight,
        aux_transducer_loss_mlp_dim=args.aux_transducer_loss_mlp_dim,
        aux_trans_loss_mlp_dropout_rate=args.aux_transducer_loss_mlp_dropout_rate,
        symm_kl_div_loss=args.use_symm_kl_div_loss,
        symm_kl_div_loss_weight=args.symm_kl_div_loss_weight,
        fastemit_lambda=args.fastemit_lambda,
        blank_id=blank_id,
        ignore_id=ignore_id,
        training=training,
    )

    if training and (args.report_cer or args.report_wer):
        # the calculator is given whichever decoder attribute was created
        self.error_calculator = ErrorCalculator(
            self.decoder if args.dtype == "custom" else self.dec,
            self.transducer_tasks.joint_network,
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None

    self.etype = args.etype
    self.dtype = args.dtype

    self.sos = odim - 1
    self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim

    self.reporter = Reporter()

    self.default_parameters(args)

    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args):
    """Build the RNN translation model (embedding, encoder, attention, decoder).

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :raises ValueError: if an embedding/classifier tying option is
        requested with incompatible settings
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()

    # the last index of the vocabulary doubles as both sos and eos
    self.sos = self.eos = odim - 1
    # NOTE: we reserve index:0 for <pad> although this is reserved for a
    # blank class in ASR. However, blank labels are not used in MT.
    # To keep the vocabulary size, we use index:0 for padding instead of
    # adding one more class.
    self.pad = 0

    # subsample info
    self.subsample = get_subsample(args, mode="mt", arch="rnn")

    # label smoothing info
    use_label_smoothing = args.lsm_type and os.path.isfile(args.train_json)
    if use_label_smoothing:
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )
    else:
        labeldist = None

    # multilingual related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)

    # encoder
    embed_dim = args.eunits
    self.embed = torch.nn.Embedding(idim, embed_dim, padding_idx=self.pad)
    self.dropout = torch.nn.Dropout(p=args.dropout_rate)
    self.enc = encoder_for(args, embed_dim, self.subsample)
    # attention
    self.att = att_for(args)
    # decoder
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

    # tie source and target embeddings
    if args.tie_src_tgt_embedding:
        if idim != odim:
            raise ValueError(
                "When using tie_src_tgt_embedding, idim and odim must be equal."
            )
        if args.eunits != args.dunits:
            raise ValueError(
                "When using tie_src_tgt_embedding, eunits and dunits must be equal."
            )
        self.embed.weight = self.dec.embed.weight

    # tie embeddings and the classifier
    if args.tie_classifier:
        if args.context_residual:
            raise ValueError(
                "When using tie_classifier, context_residual must be turned off."
            )
        self.dec.output.weight = self.dec.embed.weight

    # weight initialization
    self.init_like_fairseq()

    # options for beam search
    if args.report_bleu:
        self.trans_args = argparse.Namespace(
            beam_size=args.beam_size,
            penalty=args.penalty,
            ctc_weight=0,
            maxlenratio=args.maxlenratio,
            minlenratio=args.minlenratio,
            lm_weight=args.lm_weight,
            rnnlm=args.rnnlm,
            nbest=args.nbest,
            space=args.sym_space,
            blank=args.sym_blank,
            tgt_lang=False,
        )
        self.report_bleu = args.report_bleu
    else:
        self.report_bleu = False

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idims, odim, args):
    """Build the multi-encoder model from python-level args.

    Args:
        idims (list): list of the number of an input feature dim.
        odim (int): The number of output vocab.
        args (Namespace): arguments

    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()
    self.num_encs = args.num_encs
    self.share_ctc = args.share_ctc

    # the last index of the vocabulary doubles as both sos and eos
    self.sos = self.eos = odim - 1

    # subsample info
    self.subsample_list = get_subsample(args, mode="asr", arch="rnn_mulenc")

    # label smoothing info
    use_label_smoothing = args.lsm_type and os.path.isfile(args.train_json)
    if use_label_smoothing:
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )
    else:
        labeldist = None

    # speech translation related
    self.replace_sos = getattr(
        args, "replace_sos", False
    )  # use getattr to keep compatibility

    self.frontend = None

    # encoder
    self.enc = encoder_for(args, idims, self.subsample_list)
    # ctc
    self.ctc = ctc_for(args, odim)
    # attention, extended with the hierarchical attention network
    self.att = att_for(args)
    self.att.append(att_for(args, han_mode=True))
    # decoder
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

    if args.mtlalpha > 0 and self.num_encs > 1:
        # weights-ctc,
        # e.g. ctc_loss = w_1*ctc_1_loss + w_2 * ctc_2_loss + w_N * ctc_N_loss
        # both weight vectors are normalized by their sum
        train_weights = args.weights_ctc_train
        dec_weights = args.weights_ctc_dec
        self.weights_ctc_train = train_weights / np.sum(train_weights)
        self.weights_ctc_dec = dec_weights / np.sum(dec_weights)
        logging.info(
            "ctc weights (training during training): "
            + " ".join(map(str, self.weights_ctc_train))
        )
        logging.info(
            "ctc weights (decoding during training): "
            + " ".join(map(str, self.weights_ctc_dec))
        )
    else:
        self.weights_ctc_dec = [1.0]
        self.weights_ctc_train = [1.0]

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    if args.report_cer or args.report_wer:
        self.recog_args = argparse.Namespace(
            beam_size=args.beam_size,
            penalty=args.penalty,
            ctc_weight=args.ctc_weight,
            maxlenratio=args.maxlenratio,
            minlenratio=args.minlenratio,
            lm_weight=args.lm_weight,
            rnnlm=args.rnnlm,
            nbest=args.nbest,
            space=args.sym_space,
            blank=args.sym_blank,
            tgt_lang=False,
            ctc_weights_dec=self.weights_ctc_dec,
        )
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, args):
    """Construct an E2E object for speech translation.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    self.asr_weight = args.asr_weight
    self.mt_weight = args.mt_weight
    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
    assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose
    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()

    # below means the last number becomes eos/sos ID
    # note that sos/eos IDs are identical
    self.sos = odim - 1
    self.eos = odim - 1
    self.pad = 0
    # NOTE: we reserve index:0 for <pad> although this is reserved for a
    # blank class in ASR. However, blank labels are not used in MT.
    # To keep the vocabulary size,
    # we use index:0 for padding instead of adding one more class.

    # subsample info
    self.subsample = get_subsample(args, mode="st", arch="rnn")

    # label smoothing info
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )
    else:
        labeldist = None

    # multilingual related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)

    # encoder
    self.enc = encoder_for(args, idim, self.subsample)
    # attention (ST)
    self.att = att_for(args)
    # decoder (ST)
    self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist)

    # submodule for ASR task
    self.ctc = None
    self.att_asr = None
    self.dec_asr = None
    if self.asr_weight > 0:
        if self.mtlalpha > 0.0:
            self.ctc = CTC(
                odim,
                args.eprojs,
                args.dropout_rate,
                ctc_type=args.ctc_type,
                reduce=True,
            )
        if self.mtlalpha < 1.0:
            # attention (asr)
            self.att_asr = att_for(args)
            # decoder (asr): built from a copy of args so the forced
            # attention type does not leak into the caller's Namespace
            args_asr = copy.deepcopy(args)
            args_asr.atype = "location"  # TODO(hirofumi0810): make this option
            self.dec_asr = decoder_for(
                args_asr, odim, self.sos, self.eos, self.att_asr, labeldist
            )

    # submodule for MT task
    if self.mt_weight > 0:
        self.embed_mt = torch.nn.Embedding(
            odim, args.eunits, padding_idx=self.pad
        )
        self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate)
        # The builtin ``int`` is what the removed ``np.int`` alias stood for.
        self.enc_mt = encoder_for(
            args, args.eunits, subsample=np.ones(args.elayers + 1, dtype=int)
        )

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    # The parentheses matter: ASR error reporting needs the ASR task to be
    # active. Without them ``report_wer`` alone enabled reporting, because
    # ``and`` binds tighter than ``or``.
    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        recog_args = {
            "beam_size": args.beam_size,
            "penalty": args.penalty,
            "ctc_weight": args.ctc_weight,
            "maxlenratio": args.maxlenratio,
            "minlenratio": args.minlenratio,
            "lm_weight": args.lm_weight,
            "rnnlm": args.rnnlm,
            "nbest": args.nbest,
            "space": args.sym_space,
            "blank": args.sym_blank,
            "tgt_lang": False,
        }
        self.recog_args = argparse.Namespace(**recog_args)
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    else:
        self.report_cer = False
        self.report_wer = False

    if args.report_bleu:
        trans_args = {
            "beam_size": args.beam_size,
            "penalty": args.penalty,
            "ctc_weight": 0,
            "maxlenratio": args.maxlenratio,
            "minlenratio": args.minlenratio,
            "lm_weight": args.lm_weight,
            "rnnlm": args.rnnlm,
            "nbest": args.nbest,
            "space": args.sym_space,
            "blank": args.sym_blank,
            "tgt_lang": False,
        }
        self.trans_args = argparse.Namespace(**trans_args)
        self.report_bleu = args.report_bleu
    else:
        self.report_bleu = False

    self.rnnlm = None
    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
    """Build the transducer with a transformer or RNN encoder and decoder.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options
        ignore_id (int): stored as ``self.ignore_id``
        blank_id (int): stored as ``self.blank_id`` and passed to the loss

    """
    torch.nn.Module.__init__(self)

    transformer_encoder = args.etype == 'transformer'
    transformer_decoder = args.dtype == 'transformer'

    if transformer_encoder:
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate_encoder,
        )
        self.subsample = [1]
    else:
        self.subsample = get_subsample(args, mode='asr', arch='rnn-t')
        self.encoder = encoder_for(args, idim, self.subsample)

    if transformer_decoder:
        decoder_dropout = args.dropout_rate_decoder
        self.decoder = Decoder(
            odim=odim,
            jdim=args.joint_dim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=args.transformer_dec_input_layer,
            dropout_rate=decoder_dropout,
            positional_dropout_rate=decoder_dropout,
            attention_dropout_rate=args.transformer_attn_dropout_rate_decoder,
        )
    else:
        # with a transformer encoder, eprojs is overwritten with adim
        # before the RNN decoder is built
        if transformer_encoder:
            args.eprojs = args.adim

        # the decoder receives an attention module only in 'rnnt-att' mode
        decoder_inputs = [args, odim]
        if args.rnnt_mode == 'rnnt-att':
            self.att = att_for(args)
            decoder_inputs.append(self.att)
        self.decoder = decoder_for(*decoder_inputs)

    self.etype = args.etype
    self.dtype = args.dtype
    self.rnnt_mode = args.rnnt_mode

    self.sos = self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim
    self.adim = args.adim

    self.reporter = Reporter()

    self.criterion = TransLoss(args.trans_type, self.blank_id)

    self.default_parameters(args)

    if args.report_cer or args.report_wer:
        # imported lazily: only needed when CER/WER reporting is requested
        from espnet.nets.e2e_asr_common import ErrorCalculatorTrans
        self.error_calculator = ErrorCalculatorTrans(self.decoder, args)
    else:
        self.error_calculator = None

    self.logzero = -10000000000.0
    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Set up the transformer encoder/decoder, losses and reporting helpers.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: stored on the instance and given to the
        label smoothing loss
    """
    torch.nn.Module.__init__(self)

    # attention dropout falls back to the general dropout rate when unset
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # options that the encoder and the decoder both receive unchanged
    shared_opts = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_usebias=args.ldconv_usebias,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        span_init=getattr(args, 'span_init', None),
        span_ratio=getattr(args, 'span_ratio', None),
        ratio_adaptive=getattr(args, 'ratio_adaptive', None),
    )

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=attn_dropout,
        attention_type=getattr(args, 'transformer_enc_attn_type',
                               'self_attn'),
        max_attn_span=getattr(args, 'enc_max_attn_span', [None]),
        **shared_opts
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        attention_type=getattr(args, 'transformer_dec_attn_type',
                               'self_attn'),
        max_attn_span=getattr(args, 'dec_max_attn_span', [None]),
        **shared_opts
    )

    # sos and eos share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha

    # the CTC branch exists only when it carries weight in the loss
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None

    # attention-span related settings kept for later use on the instance
    self.attention_enc_type = getattr(args, 'transformer_enc_attn_type',
                                      'self_attn')
    self.attention_dec_type = getattr(args, 'transformer_dec_attn_type',
                                      'self_attn')
    self.span_loss_coef = getattr(args, 'span_loss_coef', None)
    self.ratio_adaptive = getattr(args, 'ratio_adaptive', None)
    self.sym_blank = args.sym_blank
def __init__(self, idim, odim, args, ignore_id=-1):
    """Create the translation model and its optional ASR / MT submodules.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: stored on the instance and given to the
        label smoothing loss
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    # attention dropout falls back to the general dropout rate when unset
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # options every encoder/decoder built here receives
    common = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
    )
    # options shared by the translation decoder and the ASR decoder
    decoder_common = dict(
        common,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
    )

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=attn_dropout,
        **common
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        **decoder_common
    )

    self.pad = 0  # use <blank> for padding
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="st", arch="transformer")
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    asr_weight = getattr(args, "asr_weight", 0.0)
    self.asr_weight = asr_weight
    if asr_weight > 0 and args.mtlalpha < 1:
        self.decoder_asr = Decoder(odim=odim, **decoder_common)

    # submodule for MT task
    mt_weight = getattr(args, "mt_weight", 0.0)
    self.mt_weight = mt_weight
    if mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer="embed",
            attention_dropout_rate=attn_dropout,
            padding_idx=0,
            **common
        )

    # NOTE: place after the submodule initialization
    self.reset_parameters(args)
    self.adim = args.adim  # used for CTC (equal to d_model)

    if asr_weight > 0 and args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True
        )
    else:
        self.ctc = None

    # translation error calculator
    self.error_calculator = MTErrorCalculator(
        args.char_list, args.sym_space, args.sym_blank, args.report_bleu
    )

    # recognition error calculator
    self.error_calculator_asr = ASRErrorCalculator(
        args.char_list,
        args.sym_space,
        args.sym_blank,
        args.report_cer,
        args.report_wer,
    )
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
def __init__(self, idim, odim, args):
    """Construct an E2E object: frontend, encoder, CTC, attention, decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    super(E2E, self).__init__()
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    self.mtlalpha = args.mtlalpha
    assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"
    self.etype = args.etype
    self.verbose = args.verbose

    # NOTE: for self.build method
    args.char_list = getattr(args, "char_list", None)
    self.char_list = args.char_list
    self.outdir = args.outdir
    self.space = args.sym_space
    self.blank = args.sym_blank
    self.reporter = Reporter()

    # the last output index doubles as both the sos and the eos ID
    last_id = odim - 1
    self.sos = last_id
    self.eos = last_id

    # subsample info
    self.subsample = get_subsample(args, mode="asr", arch="rnn")

    # label smoothing info: only when a type is set and train_json exists
    labeldist = None
    if args.lsm_type and os.path.isfile(args.train_json):
        logging.info("Use label smoothing with " + args.lsm_type)
        labeldist = label_smoothing_dist(
            odim, args.lsm_type, transcript=args.train_json
        )

    # use getattr to keep compatibility
    if not getattr(args, "use_frontend", False):
        self.frontend = None
    else:
        self.frontend = frontend_for(args, idim)
        self.feature_transform = feature_transform_for(args, (idim - 1) * 2)
        # the encoder below is sized from n_mels when the frontend is used
        idim = args.n_mels

    # encoder, ctc, attention, decoder (in this order)
    self.enc = encoder_for(args, idim, self.subsample)
    self.ctc = ctc_for(args, odim)
    self.att = att_for(args)
    self.dec = decoder_for(
        args, odim, self.sos, self.eos, self.att, labeldist
    )

    # weight initialization
    self.init_like_chainer()

    # options for beam search
    if not (args.report_cer or args.report_wer):
        self.report_cer = False
        self.report_wer = False
    else:
        self.recog_args = argparse.Namespace(
            beam_size=args.beam_size,
            penalty=args.penalty,
            ctc_weight=args.ctc_weight,
            maxlenratio=args.maxlenratio,
            minlenratio=args.minlenratio,
            lm_weight=args.lm_weight,
            rnnlm=args.rnnlm,
            nbest=args.nbest,
            space=args.sym_space,
            blank=args.sym_blank,
        )
        self.report_cer = args.report_cer
        self.report_wer = args.report_wer
    self.rnnlm = None

    self.logzero = -10000000000.0
    self.loss = None
    self.acc = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Build the transformer encoder plus an optional decoder and CTC head.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: stored on the instance and given to the
        label smoothing loss
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    # attention dropout falls back to the general dropout rate when unset
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    dropout = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=dropout,
        positional_dropout_rate=dropout,
        attention_dropout_rate=attn_dropout,
    )

    # with mtlalpha >= 1 no decoder and no label smoothing loss are built
    if not args.mtlalpha < 1:
        self.decoder = None
        self.criterion = None
    else:
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            conv_wshare=args.wshare,
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            conv_usebias=args.ldconv_usebias,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=dropout,
            positional_dropout_rate=dropout,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
        )
        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )

    self.blank = 0
    self.decoder_mode = args.decoder_mode
    if self.decoder_mode == "maskctc":
        # the last index is the mask token, so sos/eos move one index down
        self.mask_token = odim - 1
        boundary_id = odim - 2
    else:
        boundary_id = odim - 1
    self.sos = boundary_id
    self.eos = boundary_id

    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    self.reset_parameters(args)
    self.adim = args.adim  # used for CTC (equal to d_model)
    self.mtlalpha = args.mtlalpha

    # the CTC branch exists only when it carries weight in the loss
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, training=True):
    """Construct an E2E object for transducer model.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options
        ignore_id (int): stored on the instance as ``ignore_id``
        blank_id (int): stored on the instance as ``blank_id``; also passed
            to the RNN decoder and to the transducer loss
        training (bool): when False, neither the loss (``criterion``) nor
            the error calculator is created

    Raises:
        ValueError: if a custom encoder/decoder type is requested but the
            matching block architecture (``enc_block_arch`` /
            ``dec_block_arch``) is None

    """
    torch.nn.Module.__init__(self)

    self.is_rnnt = True

    custom_encoder = "custom" in args.etype

    if custom_encoder:
        if args.enc_block_arch is None:
            # each fragment ends with a space so the implicit string
            # concatenation yields a readable sentence
            raise ValueError(
                "When specifying custom encoder type, --enc-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.subsample = get_subsample(args, mode="asr", arch="transformer")

        self.encoder = CustomEncoder(
            idim,
            args.enc_block_arch,
            input_layer=args.custom_enc_input_layer,
            repeat_block=args.enc_block_repeat,
            self_attn_type=args.custom_enc_self_attn_type,
            positional_encoding_type=args.custom_enc_positional_encoding_type,
            positionwise_activation_type=args.custom_enc_pw_activation_type,
            conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
        )
        encoder_out = self.encoder.enc_out

        # copy, so the later in-place extension does not touch args
        self.most_dom_list = args.enc_block_arch[:]
    else:
        self.subsample = get_subsample(args, mode="asr", arch="rnn-t")

        self.enc = encoder_for(args, idim, self.subsample)
        encoder_out = args.eprojs

    if "custom" in args.dtype:
        if args.dec_block_arch is None:
            raise ValueError(
                "When specifying custom decoder type, --dec-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.decoder = CustomDecoder(
            odim,
            args.dec_block_arch,
            input_layer=args.custom_dec_input_layer,
            repeat_block=args.dec_block_repeat,
            positionwise_activation_type=args.custom_dec_pw_activation_type,
            dropout_rate_embed=args.dropout_rate_embed_decoder,
        )
        decoder_out = self.decoder.dunits

        if custom_encoder:
            self.most_dom_list += args.dec_block_arch[:]
        else:
            self.most_dom_list = args.dec_block_arch[:]
    else:
        self.dec = DecoderRNNT(
            odim,
            args.dtype,
            args.dlayers,
            args.dunits,
            blank_id,
            args.dec_embed_dim,
            args.dropout_rate_decoder,
            args.dropout_rate_embed_decoder,
        )
        decoder_out = args.dunits

    self.joint_network = JointNetwork(
        odim, encoder_out, decoder_out, args.joint_dim,
        args.joint_activation_type
    )

    if hasattr(self, "most_dom_list"):
        # NOTE(review): the sort key is the dimension itself, so this picks
        # the largest "d_hidden" among the blocks rather than the most
        # frequent one, and it raises IndexError when no block defines
        # "d_hidden" -- confirm this is the intent.
        self.most_dom_dim = sorted(
            Counter(
                d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d
            ).most_common(),
            key=lambda x: x[0],
            reverse=True,
        )[0][0]

    self.etype = args.etype
    self.dtype = args.dtype

    # sos and eos share the last output index
    self.sos = odim - 1
    self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim

    self.reporter = Reporter()

    if training:
        self.criterion = TransLoss(args.trans_type, self.blank_id)

    self.default_parameters(args)

    if training and (args.report_cer or args.report_wer):
        # use the same membership test that decided which decoder attribute
        # was created above, so the attribute looked up here always exists
        active_decoder = self.decoder if "custom" in self.dtype else self.dec
        self.error_calculator = ErrorCalculator(
            active_decoder,
            self.joint_network,
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None

    self.loss = None
    self.rnnlm = None