def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct a streaming (chunk-based) Transformer E2E ASR object.

    :param int idim: dimension of input features
    :param int odim: dimension of outputs (vocabulary size)
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: padding label id ignored by the loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        # NOTE: mutates ``args`` so later consumers see the resolved rate.
        args.transformer_attn_dropout_rate = args.dropout_rate
    # Chunk-based encoder: center/left/right context and hop lengths come
    # from the transformer_encoder_* options (streaming configuration).
    self.encoder = Encoder(
        idim=idim,
        center_len=args.transformer_encoder_center_chunk_len,
        left_len=args.transformer_encoder_left_chunk_len,
        hop_len=args.transformer_encoder_hop_len,
        right_len=args.transformer_encoder_right_chunk_len,
        abs_pos=args.transformer_encoder_abs_embed,
        rel_pos=args.transformer_encoder_rel_embed,
        use_mem=args.transformer_encoder_use_memory,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate
    )
    # sos and eos share the last vocabulary id.
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    # NOTE(review): fixed [1] here (no frame subsampling), unlike the offline
    # model which derives it via get_subsample — confirm this is intentional.
    self.subsample = [1]
    self.reporter = Reporter()
    # self.lsm_weight = a
    self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                        args.lsm_weight,
                                        args.transformer_length_normalized_loss)
    # self.verbose = args.verbose
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    if args.mtlalpha > 0.0:
        # Joint CTC/attention multitask training.
        self.ctc = CTC(odim, args.adim, args.dropout_rate,
                       ctc_type=args.ctc_type, reduce=True)
    else:
        self.ctc = None
    # Also built when only mtlalpha > 0 so CTC-based CER can be reported.
    if args.report_cer or args.report_wer or args.mtlalpha > 0.0:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(args.char_list,
                                                args.sym_space, args.sym_blank,
                                                args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    decoder: AbsDecoder,
    ctc: CTC,
    rnnt_decoder: None,
    ctc_weight: float = 0.5,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
):
    """Construct a hybrid CTC/attention ASR model.

    Wires the feature pipeline (frontend -> specaug -> normalize ->
    preencoder -> encoder -> decoder), the label-smoothing attention loss,
    the optional CTC branch, and the CER/WER error calculator.
    """
    assert check_argument_types()
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    # Transducer decoding is not supported by this model.
    assert rnnt_decoder is None, "Not implemented"
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    # Copy so later mutation of the caller's list does not affect the model.
    self.token_list = token_list.copy()
    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    # NOTE(review): attribute name looks misspelled ("adddiontal" vs
    # "additional") — kept as-is because external code may read this name;
    # confirm before renaming.
    self.adddiontal_utt_mvn = None
    self.preencoder = preencoder
    self.encoder = encoder
    self.decoder = decoder
    # Pure-attention mode: drop the CTC branch entirely.
    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc
    self.rnnt_decoder = rnnt_decoder
    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )
    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(token_list, sym_space,
                                                sym_blank, report_cer,
                                                report_wer)
    else:
        self.error_calculator = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: padding label id ignored by the loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        # NOTE: mutates ``args`` so later consumers see the resolved rate.
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate
    )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate
    )
    # sos and eos share the last vocabulary id.
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    # Subsampling factors derived from the ASR transformer configuration.
    self.subsample = get_subsample(args, mode='asr', arch='transformer')
    self.reporter = Reporter()
    # self.lsm_weight = a
    self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                        args.lsm_weight,
                                        args.transformer_length_normalized_loss)
    # self.verbose = args.verbose
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    if args.mtlalpha > 0.0:
        # Joint CTC/attention multitask training.
        self.ctc = CTC(odim, args.adim, args.dropout_rate,
                       ctc_type=args.ctc_type, reduce=True)
    else:
        self.ctc = None
    if args.report_cer or args.report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(args.char_list,
                                                args.sym_space, args.sym_blank,
                                                args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
    """Construct a chainer Transformer E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: padding label id ignored by the loss
    :param bool flag_return: whether forward returns a tuple of losses
    """
    chainer.Chain.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
    # Fall back to the generic dropout rate when no attention-specific
    # rate is configured.
    if args.transformer_attn_dropout_rate is None:
        self.dropout = args.dropout_rate
    else:
        self.dropout = args.transformer_attn_dropout_rate
    self.use_label_smoothing = False
    self.char_list = args.char_list
    self.space = args.sym_space
    self.blank = args.sym_blank
    # Embedding scale sqrt(d_model), as in "Attention Is All You Need".
    self.scale_emb = args.adim**0.5
    # sos and eos share the last vocabulary id.
    self.sos = odim - 1
    self.eos = odim - 1
    self.subsample = [0]
    self.ignore_id = ignore_id
    self.reset_parameters(args)
    # Child links must be created inside init_scope so chainer registers
    # their parameters for optimization/serialization.
    with self.init_scope():
        self.encoder = Encoder(idim, args, initialW=self.initialW,
                               initial_bias=self.initialB)
        self.decoder = Decoder(odim, args, initialW=self.initialW,
                               initial_bias=self.initialB)
        self.criterion = LabelSmoothingLoss(
            args.lsm_weight, len(args.char_list),
            args.transformer_length_normalized_loss)
        if args.mtlalpha > 0.0:
            if args.ctc_type == 'builtin':
                logging.info("Using chainer CTC implementation")
                self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
            elif args.ctc_type == 'warpctc':
                logging.info("Using warpctc CTC implementation")
                self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
            else:
                raise ValueError(
                    'ctc_type must be "builtin" or "warpctc": {}'.format(
                        args.ctc_type))
        else:
            self.ctc = None
    self.dims = args.adim
    self.odim = odim
    self.flag_return = flag_return
    if args.report_cer or args.report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(args.char_list,
                                                args.sym_space, args.sym_blank,
                                                args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    # Namespace from argparse supports "in"; other config objects may
    # expose a possibly-None ``verbose`` attribute instead.
    if 'Namespace' in str(type(args)):
        self.verbose = 0 if 'verbose' not in args else args.verbose
    else:
        self.verbose = 0 if args.verbose is None else args.verbose
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = False,
    report_wer: bool = False,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    pred_masked_weight: float = 1.0,
    pred_nomask_weight: float = 0.0,
    loss_weights: float = 0.0,
):
    """Construct a Hubert-style pretraining model.

    Assembles the feature pipeline (frontend -> specaug -> normalize ->
    preencoder -> encoder) and the Hubert pretraining criterion; there is
    no decoder or CTC branch in this model.
    """
    assert check_argument_types()
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    # Copy so later mutation of the caller's list does not affect the model.
    self.token_list = token_list.copy()
    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.encoder = encoder
    # Weighted combination of masked / unmasked prediction losses plus
    # optional auxiliary feature losses.
    self.criterion_att = HubertPretrainLoss(
        pred_masked_weight,
        pred_nomask_weight,
        loss_weights,
    )
    self.pred_masked_weight = pred_masked_weight
    self.pred_nomask_weight = pred_nomask_weight
    self.loss_weights = loss_weights
    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(token_list, sym_space,
                                                sym_blank, report_cer,
                                                report_wer)
    else:
        self.error_calculator = None
def test_error_calculator_nospace(tmpdir):
    """ErrorCalculator works when the char list has no <space> symbol.

    Exercises both the CTC path (``is_ctc=True``) and the attention path,
    asserting that CER/WER values are produced in all cases.
    """
    from espnet.nets.e2e_asr_common import ErrorCalculator

    # Seed the RNG so a failure is reproducible instead of flaky.
    np.random.seed(0)
    space = "<space>"
    blank = "<blank>"
    # Deliberately omit ``space`` from the char list ("nospace" case).
    char_list = [blank, 'a', 'e', 'i', 'o', 'u']
    ys_pad = [np.random.randint(0, len(char_list), x)
              for x in range(120, 150, 5)]
    ys_hat = [np.random.randint(0, len(char_list), x)
              for x in range(120, 150, 5)]
    cer, wer = True, True
    ec = ErrorCalculator(char_list, space, blank, cer, wer)
    cer_ctc_val = ec(ys_pad, ys_hat, is_ctc=True)
    _cer, _wer = ec(ys_pad, ys_hat)
    assert cer_ctc_val is not None
    assert _cer is not None
    assert _wer is not None
def test_error_calculator(tmpdir, typ):
    """Check CER/WER reporting for each flag combination selected by ``typ``."""
    from espnet.nets.e2e_asr_common import ErrorCalculator

    space = "<space>"
    blank = "<blank>"
    char_list = [blank, space, "a", "e", "i", "o", "u"]
    ys_pad = [np.random.randint(0, len(char_list), n)
              for n in range(120, 150, 5)]
    ys_hat = [np.random.randint(0, len(char_list), n)
              for n in range(120, 150, 5)]

    # Map ``typ`` to the (cer, wer) reporting flags; anything else means both.
    flags_by_typ = {
        "ctc": (False, False),
        "wer": (False, True),
        "cer": (True, False),
    }
    cer, wer = flags_by_typ.get(typ, (True, True))
    ec = ErrorCalculator(char_list, space, blank, cer, wer)

    # The CTC-based CER path is exercised for "ctc" and the default case,
    # i.e. exactly when both flags agree.
    if cer == wer:
        assert ec(ys_pad, ys_hat, is_ctc=True) is not None

    # Each metric is reported iff its flag was set.
    _cer, _wer = ec(ys_pad, ys_hat)
    assert (_cer is not None) == cer
    assert (_wer is not None) == wer
def test_error_calculator(tmpdir, typ):
    """Verify CER/WER outputs for the reporting mode named by ``typ``."""
    from espnet.nets.e2e_asr_common import ErrorCalculator

    space = "<space>"
    blank = "<blank>"
    char_list = [blank, space, 'a', 'e', 'i', 'o', 'u']
    lengths = range(120, 150, 5)
    ys_pad = [np.random.randint(0, 7, n) for n in lengths]
    ys_hat = [np.random.randint(0, 7, n) for n in lengths]

    # Derive the reporting flags: 'ctc' disables both, 'cer'/'wer' enable
    # only themselves, any other value enables both.
    cer = typ not in ('ctc', 'wer')
    wer = typ not in ('ctc', 'cer')
    ec = ErrorCalculator(char_list, space, blank, cer, wer)

    # 'ctc' and the default mode also exercise the CTC-based CER path.
    if cer == wer:
        assert ec(ys_pad, ys_hat, is_ctc=True) is not None

    _cer, _wer = ec(ys_pad, ys_hat)
    if cer:
        assert _cer is not None
    else:
        assert _cer is None
    if wer:
        assert _wer is not None
    else:
        assert _wer is None
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    postencoder: Optional[AbsPostEncoder],
    decoder: MLMDecoder,
    ctc: CTC,
    joint_network: Optional[torch.nn.Module] = None,
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    sym_mask: str = "<mask>",
    extract_feats_in_collect_stats: bool = True,
):
    """Construct a masked-LM (MaskCTC-style) ASR model.

    Delegates the common CTC/attention setup to the parent class, then
    extends the vocabulary with a ``<mask>`` token and swaps the attention
    criterion for an MLM criterion over the enlarged vocabulary.
    """
    assert check_argument_types()
    super().__init__(
        vocab_size=vocab_size,
        token_list=token_list,
        frontend=frontend,
        specaug=specaug,
        normalize=normalize,
        preencoder=preencoder,
        encoder=encoder,
        postencoder=postencoder,
        decoder=decoder,
        ctc=ctc,
        joint_network=joint_network,
        ctc_weight=ctc_weight,
        interctc_weight=interctc_weight,
        ignore_id=ignore_id,
        lsm_weight=lsm_weight,
        length_normalized_loss=length_normalized_loss,
        report_cer=report_cer,
        report_wer=report_wer,
        sym_space=sym_space,
        sym_blank=sym_blank,
        extract_feats_in_collect_stats=extract_feats_in_collect_stats,
    )
    # Add <mask> and override inherited fields
    # NOTE(review): appends to the caller's token_list in place — confirm
    # callers expect this mutation.
    token_list.append(sym_mask)
    vocab_size += 1
    self.vocab_size = vocab_size
    # <mask> occupies the new last id.
    self.mask_token = vocab_size - 1
    self.token_list = token_list.copy()
    # MLM loss
    del self.criterion_att
    self.criterion_mlm = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )
    # Rebuild over the extended token list (replaces the parent's one).
    self.error_calculator = None
    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(token_list, sym_space,
                                                sym_blank, report_cer,
                                                report_wer)
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    postencoder: Optional[AbsPostEncoder],
    decoder: AbsDecoder,
    ctc: CTC,
    rnnt_decoder: None,
    ctc_weight: float = 0.5,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
):
    """Construct a hybrid CTC/attention ASR model with optional postencoder.

    Wires frontend -> specaug -> normalize -> preencoder -> encoder ->
    postencoder -> decoder, plus the label-smoothing attention loss, the
    optional CTC branch, and the CER/WER error calculator.
    """
    assert check_argument_types()
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    # Transducer decoding is not supported by this model.
    assert rnnt_decoder is None, "Not implemented"
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    # Copy so later mutation of the caller's list does not affect the model.
    self.token_list = token_list.copy()
    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.postencoder = postencoder
    self.encoder = encoder
    # we set self.decoder = None in the CTC mode since
    # self.decoder parameters were never used and PyTorch complained
    # and threw an Exception in the multi-GPU experiment.
    # thanks Jeff Farris for pointing out the issue.
    if ctc_weight == 1.0:
        self.decoder = None
    else:
        self.decoder = decoder
    # Pure-attention mode: drop the CTC branch entirely.
    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc
    self.rnnt_decoder = rnnt_decoder
    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )
    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(token_list, sym_space,
                                                sym_blank, report_cer,
                                                report_wer)
    else:
        self.error_calculator = None
    self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
def __init__(self, idim, odim, args, ignore_id=-1, flag_return=True):
    """Initialize the transformer.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: padding label id ignored by the loss
    :param bool flag_return: whether forward returns a tuple of losses
    """
    chainer.Chain.__init__(self)
    self.mtlalpha = args.mtlalpha
    assert 0 <= self.mtlalpha <= 1, "mtlalpha must be [0,1]"
    if args.transformer_attn_dropout_rate is None:
        # NOTE: mutates ``args`` so later consumers see the resolved rate.
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.use_label_smoothing = False
    self.char_list = args.char_list
    self.space = args.sym_space
    self.blank = args.sym_blank
    # Embedding scale sqrt(d_model), as in "Attention Is All You Need".
    self.scale_emb = args.adim**0.5
    # sos and eos share the last vocabulary id.
    self.sos = odim - 1
    self.eos = odim - 1
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.ignore_id = ignore_id
    self.reset_parameters(args)
    # Child links must be created inside init_scope so chainer registers
    # their parameters for optimization/serialization.
    with self.init_scope():
        # NOTE(review): num_blocks (args.elayers) is not passed here, so the
        # encoder depth falls back to Encoder's default — confirm intended.
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            initialW=self.initialW,
            initial_bias=self.initialB,
        )
        self.decoder = Decoder(odim, args, initialW=self.initialW,
                               initial_bias=self.initialB)
        self.criterion = LabelSmoothingLoss(
            args.lsm_weight,
            len(args.char_list),
            args.transformer_length_normalized_loss,
        )
        if args.mtlalpha > 0.0:
            if args.ctc_type == "builtin":
                logging.info("Using chainer CTC implementation")
                self.ctc = ctc.CTC(odim, args.adim, args.dropout_rate)
            elif args.ctc_type == "warpctc":
                logging.info("Using warpctc CTC implementation")
                self.ctc = ctc.WarpCTC(odim, args.adim, args.dropout_rate)
            else:
                raise ValueError(
                    'ctc_type must be "builtin" or "warpctc": {}'.format(
                        args.ctc_type))
        else:
            self.ctc = None
    self.dims = args.adim
    self.odim = odim
    self.flag_return = flag_return
    if args.report_cer or args.report_wer:
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    # Namespace from argparse supports "in"; other config objects may
    # expose a possibly-None ``verbose`` attribute instead.
    if "Namespace" in str(type(args)):
        self.verbose = 0 if "verbose" not in args else args.verbose
    else:
        self.verbose = 0 if args.verbose is None else args.verbose
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: padding label id ignored by the loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        # NOTE: mutates ``args`` so later consumers see the resolved rate.
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder_type = getattr(args, 'encoder_type', 'all_add')
    self.vbs = getattr(args, 'vbs', False)
    self.noise = getattr(args, 'noise_type', 'none')
    # Select the multimodal-fusion encoder implementation by config.
    # NOTE(review): an unrecognized encoder_type leaves MultimodalEncoder
    # unbound and raises NameError below — consider an explicit error.
    if self.encoder_type == 'all_add':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_all_add import MultimodalEncoder
    elif self.encoder_type == 'proportion_add':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_proportion_add import MultimodalEncoder
    elif self.encoder_type == 'vat':
        from espnet.nets.pytorch_backend.transformer.multimodal_encoder_vat import MultimodalEncoder
    self.encoder = MultimodalEncoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        visual_dim=args.visual_dim,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        vbs=self.vbs)
    self.decoder = MultimodalDecoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate)
    self.pad = 0
    # sos and eos share the last vocabulary id.
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()
    # self.lsm_weight = a
    self.criterion = LabelSmoothingLoss(
        self.odim, self.ignore_id, args.lsm_weight,
        args.transformer_length_normalized_loss)
    # self.verbose = args.verbose
    self.adim = args.adim
    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    if self.asr_weight > 0 and args.mtlalpha < 1:
        # Auxiliary attention-based ASR decoder.
        self.decoder_asr = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )
    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        # Auxiliary text encoder for the MT objective (embeds target ids).
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0)
    self.reset_parameters(args)  # place after the submodule initialization
    if args.mtlalpha > 0.0:
        self.ctc = CTC(odim, args.adim, args.dropout_rate,
                       ctc_type=args.ctc_type, reduce=True)
    else:
        self.ctc = None
    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(args.char_list,
                                                args.sym_space, args.sym_blank,
                                                args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    self.rnnlm = None
    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    if self.multilingual:
        assert self.replace_sos
    # NOTE(review): hard-codes device selection at construction time;
    # confirm this does not conflict with later .to(device) calls.
    self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: padding label id ignored by the loss
    """
    torch.nn.Module.__init__(self)
    if args.transformer_attn_dropout_rate is None:
        # NOTE: mutates ``args`` so later consumers see the resolved rate.
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate
    )
    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    self.do_asr = self.asr_weight > 0 and args.mtlalpha < 1

    # cross-attention parameters
    self.cross_weight = getattr(args, "cross_weight", 0.0)
    self.cross_self = getattr(args, "cross_self", False)
    self.cross_src = getattr(args, "cross_src", False)
    self.cross_operator = getattr(args, "cross_operator", None)
    self.cross_to_asr = getattr(args, "cross_to_asr", False)
    self.cross_to_st = getattr(args, "cross_to_st", False)
    self.num_decoders = getattr(args, "num_decoders", 1)
    # wait-k streaming offsets: at most one of the two may be non-zero.
    self.wait_k_asr = getattr(args, "wait_k_asr", 0)
    self.wait_k_st = getattr(args, "wait_k_st", 0)
    self.cross_src_from = getattr(args, "cross_src_from", "embedding")
    self.cross_self_from = getattr(args, "cross_self_from", "embedding")
    self.cross_weight_learnable = getattr(args, "cross_weight_learnable", False)

    # one-to-many ST experiments
    self.one_to_many = getattr(args, "one_to_many", False)
    self.langs_dict = getattr(args, "langs_dict", None)
    self.lang_tok = getattr(args, "lang_tok", None)

    self.normalize_before = getattr(args, "normalize_before", True)
    logging.info(f'self.normalize_before = {self.normalize_before}')

    # Check parameters
    if self.cross_operator == 'sum' and self.cross_weight <= 0:
        assert (not self.cross_to_asr) and (not self.cross_to_st)
    if self.cross_to_asr or self.cross_to_st:
        assert self.do_asr
        assert self.cross_self or self.cross_src
    # A cross operator must be configured iff cross-attention is active.
    assert bool(self.cross_operator) == (self.do_asr and (self.cross_to_asr or self.cross_to_st))
    if self.cross_src_from != "embedding" or self.cross_self_from != "embedding":
        assert self.normalize_before
    if self.wait_k_asr > 0:
        assert self.wait_k_st == 0
    elif self.wait_k_st > 0:
        assert self.wait_k_asr == 0
    else:
        assert self.wait_k_asr == 0
        assert self.wait_k_st == 0

    logging.info("*** Cross attention parameters ***")
    if self.cross_to_asr:
        logging.info("| Cross to ASR")
    if self.cross_to_st:
        logging.info("| Cross to ST")
    if self.cross_self:
        logging.info("| Cross at Self")
    if self.cross_src:
        logging.info("| Cross at Source")
    if self.cross_to_asr or self.cross_to_st:
        logging.info(f'| Cross operator: {self.cross_operator}')
        logging.info(f'| Cross sum weight: {self.cross_weight}')
        if self.cross_src:
            logging.info(f'| Cross source from: {self.cross_src_from}')
        if self.cross_self:
            logging.info(f'| Cross self from: {self.cross_self_from}')
    logging.info(f'| wait_k_asr = {self.wait_k_asr}')
    logging.info(f'| wait_k_st = {self.wait_k_st}')

    if (self.cross_src_from != "embedding" and self.cross_src) and (not self.normalize_before):
        logging.warning(f'WARNING: Resort to using self.cross_src_from == embedding for cross at source attention.')
    if (self.cross_self_from != "embedding" and self.cross_self) and (not self.normalize_before):
        logging.warning(f'WARNING: Resort to using self.cross_self_from == embedding for cross at self attention.')

    # Dual decoder: joint ASR+ST decoding with cross-attention between
    # the two decoder streams.
    self.dual_decoder = DualDecoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        normalize_before=self.normalize_before,
        cross_operator=self.cross_operator,
        cross_weight_learnable=self.cross_weight_learnable,
        cross_weight=self.cross_weight,
        cross_self=self.cross_self,
        cross_src=self.cross_src,
        cross_to_asr=self.cross_to_asr,
        cross_to_st=self.cross_to_st
    )

    self.pad = 0
    # sos and eos share the last vocabulary id.
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.idim = idim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()

    # self.lsm_weight = a
    self.criterion = LabelSmoothingLoss(self.odim, self.ignore_id,
                                        args.lsm_weight,
                                        args.transformer_length_normalized_loss)
    # self.verbose = args.verbose
    self.adim = args.adim

    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        # Auxiliary text encoder for the MT objective (embeds target ids).
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0
        )
    self.reset_parameters(args)  # place after the submodule initialization
    if args.mtlalpha > 0.0:
        self.ctc = CTC(odim, args.adim, args.dropout_rate,
                       ctc_type=args.ctc_type, reduce=True)
    else:
        self.ctc = None

    if self.asr_weight > 0 and (args.report_cer or args.report_wer):
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(args.char_list,
                                                args.sym_space, args.sym_blank,
                                                args.report_cer, args.report_wer)
    else:
        self.error_calculator = None
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    if self.multilingual:
        assert self.replace_sos

    if self.lang_tok == "encoder-pre-sum":
        # Language embeddings summed onto the encoder input features.
        self.language_embeddings = build_embedding(self.langs_dict, self.idim,
                                                   padding_idx=self.pad)
        print(f'language_embeddings: {self.language_embeddings}')
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    enh: Optional[AbsEnhancement],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    encoder: AbsEncoder,
    decoder: AbsDecoder,
    ctc: CTC,
    rnnt_decoder: None,
    ctc_weight: float = 0.5,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    enh_weight: float = 0.5,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
):
    """Construct a joint speech-enhancement + ASR model.

    Wires the enhancement front-end before the usual hybrid CTC/attention
    ASR pipeline and validates the enhancement loss configuration.
    """
    assert check_argument_types()
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    # FIX: the failure message previously reported ctc_weight for the
    # enh_weight bound check, hiding the actual offending value.
    assert 0.0 <= enh_weight <= 1.0, enh_weight
    # Transducer decoding is not supported by this model.
    assert rnnt_decoder is None, "Not implemented"
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    self.enh_weight = enh_weight
    # Copy so later mutation of the caller's list does not affect the model.
    self.token_list = token_list.copy()

    self.enh_model = enh
    self.num_spk = enh.num_spk
    self.mask_type = getattr(self.enh_model, "mask_type", None)
    # get loss type for model training
    self.loss_type = getattr(self.enh_model, "loss_type", None)
    assert self.loss_type in (
        # mse_loss(predicted_mask, target_label)
        "mask_mse",
        # mse_loss(enhanced_magnitude_spectrum, target_magnitude_spectrum)
        "magnitude",
        # mse_loss(enhanced_complex_spectrum, target_complex_spectrum)
        "spectrum",
        # si_snr(enhanced_waveform, target_waveform)
        "si_snr",
    ), self.loss_type

    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.encoder = encoder
    self.decoder = decoder
    # Pure-attention mode: drop the CTC branch entirely.
    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc
    self.rnnt_decoder = rnnt_decoder
    self.criterion_att = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )
    if report_cer or report_wer:
        self.error_calculator = ErrorCalculator(token_list, sym_space,
                                                sym_blank, report_cer,
                                                report_wer)
    else:
        self.error_calculator = None
    # TODO(Jing): find out the -1 or 0 here
    # self.idx_blank = token_list.index(sym_blank)  # 0
    self.idx_blank = -1
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    postencoder: Optional[AbsPostEncoder],
    decoder: AbsDecoder,
    ctc: CTC,
    joint_network: Optional[torch.nn.Module],
    ctc_weight: float = 0.5,
    interctc_weight: float = 0.0,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
):
    """Construct an ASR model supporting CTC/attention and transducer modes.

    When ``joint_network`` is given the model trains with the RNN-T loss;
    otherwise it uses the hybrid CTC/attention setup. Intermediate-CTC
    conditioning is enabled when the encoder supports it.
    """
    assert check_argument_types()
    assert 0.0 <= ctc_weight <= 1.0, ctc_weight
    assert 0.0 <= interctc_weight < 1.0, interctc_weight
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.blank_id = 0
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.ignore_id = ignore_id
    self.ctc_weight = ctc_weight
    self.interctc_weight = interctc_weight
    # Copy so later mutation of the caller's list does not affect the model.
    self.token_list = token_list.copy()

    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.postencoder = postencoder
    self.encoder = encoder

    # Older encoders may lack the intermediate-CTC attribute; default off.
    if not hasattr(self.encoder, "interctc_use_conditioning"):
        self.encoder.interctc_use_conditioning = False
    if self.encoder.interctc_use_conditioning:
        # Projects intermediate CTC posteriors back into the encoder stream.
        self.encoder.conditioning_layer = torch.nn.Linear(
            vocab_size, self.encoder.output_size()
        )

    self.use_transducer_decoder = joint_network is not None
    self.error_calculator = None

    if self.use_transducer_decoder:
        from warprnnt_pytorch import RNNTLoss

        self.decoder = decoder
        self.joint_network = joint_network
        self.criterion_transducer = RNNTLoss(
            blank=self.blank_id,
            fastemit_lambda=0.0,
        )
        if report_cer or report_wer:
            self.error_calculator_trans = ErrorCalculatorTransducer(
                decoder,
                joint_network,
                token_list,
                sym_space,
                sym_blank,
                report_cer=report_cer,
                report_wer=report_wer,
            )
        else:
            self.error_calculator_trans = None
            # Without transducer error reporting, fall back to CTC-based
            # CER/WER when a CTC branch exists.
            if self.ctc_weight != 0:
                self.error_calculator = ErrorCalculator(
                    token_list, sym_space, sym_blank, report_cer, report_wer
                )
    else:
        # we set self.decoder = None in the CTC mode since
        # self.decoder parameters were never used and PyTorch complained
        # and threw an Exception in the multi-GPU experiment.
        # thanks Jeff Farris for pointing out the issue.
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        if report_cer or report_wer:
            self.error_calculator = ErrorCalculator(
                token_list, sym_space, sym_blank, report_cer, report_wer
            )

    # Pure-attention mode: drop the CTC branch entirely.
    if ctc_weight == 0.0:
        self.ctc = None
    else:
        self.ctc = ctc
    self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
def __init__(self, num_time_mask=2, num_freq_mask=2, freq_mask_length=15, time_mask_length=15, feature_dim=320, model_size=512, feed_forward_size=1024, hidden_size=64, dropout=0.1, num_head=8, num_encoder_layer=6, num_decoder_layer=6, vocab_path='testing_vocab.model', max_feature_length=1024, max_token_length=50, enable_spec_augment=True, share_weight=True, smoothing=0.1, restrict_left_length=20, restrict_right_length=20, mtlalpha=0.2, report_wer=True): super(Transformer, self).__init__() self.enable_spec_augment = enable_spec_augment self.max_token_length = max_token_length self.restrict_left_length = restrict_left_length self.restrict_right_length = restrict_right_length self.vocab = Vocab(vocab_path) self.sos = self.vocab.bos_id self.eos = self.vocab.eos_id self.adim = model_size self.odim = self.vocab.vocab_size self.ignore_id = self.vocab.pad_id if enable_spec_augment: self.spec_augment = SpecAugment( num_time_mask=num_time_mask, num_freq_mask=num_freq_mask, freq_mask_length=freq_mask_length, time_mask_length=time_mask_length, max_sequence_length=max_feature_length) self.encoder = Encoder(idim=feature_dim, attention_dim=model_size, attention_heads=num_head, linear_units=feed_forward_size, num_blocks=num_encoder_layer, dropout_rate=dropout, positional_dropout_rate=dropout, attention_dropout_rate=dropout, input_layer='linear', padding_idx=self.vocab.pad_id) self.decoder = Decoder(odim=self.vocab.vocab_size, attention_dim=model_size, attention_heads=num_head, linear_units=feed_forward_size, num_blocks=num_decoder_layer, dropout_rate=dropout, positional_dropout_rate=dropout, self_attention_dropout_rate=dropout, src_attention_dropout_rate=0, input_layer='embed', use_output_layer=False) self.decoder_linear = t.nn.Linear(model_size, self.vocab.vocab_size, bias=True) self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True) self.criterion = LabelSmoothingLoss(size=self.odim, smoothing=smoothing, padding_idx=self.vocab.pad_id, normalize_length=True) 
self.switch_criterion = LabelSmoothingLoss( size=4, smoothing=0, padding_idx=self.vocab.pad_id, normalize_length=True) self.mtlalpha = mtlalpha if mtlalpha > 0.0: self.ctc = CTC(self.odim, eprojs=self.adim, dropout_rate=dropout, ctc_type='builtin', reduce=False) else: self.ctc = None if report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator def load_token_list(path=vocab_path.replace('.model', '.vocab')): with open(path) as reader: data = reader.readlines() data = [i.split('\t')[0] for i in data] return data self.char_list = load_token_list() self.error_calculator = ErrorCalculator( char_list=self.char_list, sym_space=' ', sym_blank=self.vocab.blank_token, report_wer=True) else: self.error_calculator = None self.rnnlm = None self.reporter = Reporter() self.switch_loss = LabelSmoothingLoss(size=4, smoothing=0, padding_idx=0) print('initing') initialize(self, init_type='xavier_normal') print('inited')
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.cn_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.en_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, 
args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True self.reset_parameters(args) # reset params at the last self.enc_lambda = args.enc_lambda logging.warning("Using fixed encoder lambda: {}".format( self.enc_lambda)) logging.warning( "Model total size: {}M, requires_grad size: {}M".format( self.count_parameters(), self.count_parameters(requires_grad=True)))
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.cn_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.en_encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) # gated add module self.vectorize_lambda = args.vectorize_lambda lambda_dim = args.adim if self.vectorize_lambda else 1 self.aggregation_module = torch.nn.Sequential( torch.nn.Linear(2 * args.adim, lambda_dim), torch.nn.Sigmoid()) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, 
ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True self.reset_parameters(args) # reset params at the last # we frozen params here if args.activated_keys: activated_keys = args.activated_keys.split(',') for name, params in self.named_parameters(): requires_grad = False # by default, we'd like to frozen all params for key in activated_keys: if key in name: requires_grad = True # hit the key, activate this param params.requires_grad = requires_grad else: logging.warning("Not frozen anything.") logging.warning( "Model total size: {}M, requires_grad size: {}M".format( self.count_parameters(), self.count_parameters(requires_grad=True)))
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculator self.error_calculator = ErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer) else: self.error_calculator = None self.rnnlm = None # yzl23 config self.remove_blank_in_ctc_mode = True # lid multitask related adim = args.adim self.lid_odim = 2 # cn and en # src attention self.lid_src_att = MultiHeadedAttention( args.aheads, args.adim, 
args.transformer_attn_dropout_rate) # self.lid_output_layer = torch.nn.Sequential(torch.nn.Linear(adim, adim), # torch.nn.Tanh(), # torch.nn.Linear(adim, self.lid_odim)) self.lid_output_layer = torch.nn.Linear(adim, self.lid_odim) # here we hack to use lsm loss, but with lsm_weight ZERO self.lid_criterion = LanguageIDMultitakLoss(self.ignore_id, \ normalize_length=args.transformer_length_normalized_loss) self.lid_mtl_alpha = args.lid_mtl_alpha logging.warning("language id multitask training alpha %f" % (self.lid_mtl_alpha)) self.log_lid_mtl_acc = args.log_lid_mtl_acc # reset parameters self.reset_parameters(args)
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) # target matching system organization self.oversampling = args.oversampling self.residual = args.residual self.outer = args.outer self.poster = torch.nn.Linear(args.adim, odim * self.oversampling) if self.outer: if self.residual: self.matcher_res = torch.nn.Linear(idim, odim) self.matcher = torch.nn.Linear(odim, odim) else: self.matcher = torch.nn.Linear(odim + idim, odim) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC( odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True ) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) if args.mtlalpha < 1: self.decoder = Decoder( odim=odim, selfattention_layer_type=args.transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.criterion = LabelSmoothingLoss( odim, ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) else: self.decoder = None self.criterion = None self.blank = 0 self.decoder_mode = args.decoder_mode if self.decoder_mode == "maskctc": self.mask_token = odim - 1 self.sos = odim - 2 self.eos = odim - 2 else: self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = 
Reporter() self.reset_parameters(args) self.adim = args.adim # used for CTC (equal to d_model) self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC( odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True ) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, attention_type=getattr(args, 'transformer_enc_attn_type', 'self_attn'), max_attn_span=getattr(args, 'enc_max_attn_span', [None]), span_init=getattr(args, 'span_init', None), span_ratio=getattr(args, 'span_ratio', None), ratio_adaptive=getattr(args, 'ratio_adaptive', None)) self.decoder = Decoder( odim=odim, selfattention_layer_type=args. 
transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, attention_type=getattr(args, 'transformer_dec_attn_type', 'self_attn'), max_attn_span=getattr(args, 'dec_max_attn_span', [None]), span_init=getattr(args, 'span_init', None), span_ratio=getattr(args, 'span_ratio', None), ratio_adaptive=getattr(args, 'ratio_adaptive', None)) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() # self.lsm_weight = a self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None self.attention_enc_type = getattr(args, 'transformer_enc_attn_type', 'self_attn') self.attention_dec_type = getattr(args, 'transformer_dec_attn_type', 'self_attn') self.span_loss_coef = getattr(args, 'span_loss_coef', None) self.ratio_adaptive = getattr(args, 'ratio_adaptive', None) self.sym_blank = args.sym_blank