def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    postencoder: Optional[AbsPostEncoder],
    decoder: AbsDecoder,
    src_vocab_size: int = 0,
    src_token_list: Union[Tuple[str, ...], List[str]] = [],
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_bleu: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
    share_decoder_input_output_embed: bool = False,
    share_encoder_decoder_input_embed: bool = False,
):
    assert check_argument_types()
    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.src_vocab_size = src_vocab_size
    self.ignore_id = ignore_id
    self.token_list = token_list.copy()

    if share_decoder_input_output_embed:
        if decoder.output_layer is not None:
            decoder.output_layer.weight = decoder.embed[0].weight
            logging.info(
                "Decoder input embedding and output linear layer are shared"
            )
        else:
            logging.warning(
                "Decoder has no output layer, so it cannot be shared "
                "with input embedding"
            )

    if share_encoder_decoder_input_embed:
        if src_vocab_size == vocab_size:
            frontend.embed[0].weight = decoder.embed[0].weight
            logging.info("Encoder and decoder input embeddings are shared")
        else:
            logging.warning(
                f"src_vocab_size ({src_vocab_size}) does not equal tgt_vocab_size"
                f" ({vocab_size}), so the encoder and decoder input embeddings "
                "cannot be shared"
            )

    self.frontend = frontend
    self.preencoder = preencoder
    self.postencoder = postencoder
    self.encoder = encoder
    self.decoder = decoder
    self.criterion_mt = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    # MT error calculator
    if report_bleu:
        self.mt_error_calculator = MTErrorCalculator(
            token_list, sym_space, sym_blank, report_bleu
        )
    else:
        self.mt_error_calculator = None

    self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
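# --- Illustrative sketch, not part of the original model code ---
# The two `share_*` branches above are plain PyTorch weight tying: the output
# projection (or the encoder's input embedding) is pointed at the decoder's
# input embedding so both use one parameter tensor. A minimal standalone
# example with toy nn.Embedding / nn.Linear modules standing in for the real
# ESPnet submodules; the helper name `tie_weights` is hypothetical.
import torch
import torch.nn as nn


def tie_weights(embed: nn.Embedding, output_layer: nn.Linear) -> None:
    """Share one (vocab_size, hidden) matrix between embedding and projection."""
    assert embed.weight.shape == output_layer.weight.shape
    output_layer.weight = embed.weight  # both point at the same Parameter now


vocab, hidden = 1000, 256
embed = nn.Embedding(vocab, hidden)
proj = nn.Linear(hidden, vocab, bias=False)
tie_weights(embed, proj)
# Updating one updates the other: they reference the same storage.
assert proj.weight.data_ptr() == embed.weight.data_ptr()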
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.pad = 0  # use <blank> for padding
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="st", arch="transformer")
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # submodule for ASR task
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    if self.asr_weight > 0 and args.mtlalpha < 1:
        self.decoder_asr = Decoder(
            odim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )

    # submodule for MT task
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer="embed",
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0,
        )
    self.reset_parameters(args)  # NOTE: place after the submodule initialization
    self.adim = args.adim  # used for CTC (equal to d_model)
    if self.asr_weight > 0 and args.mtlalpha > 0.0:
        self.ctc = CTC(
            odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True
        )
    else:
        self.ctc = None

    # translation error calculator
    self.error_calculator = MTErrorCalculator(
        args.char_list, args.sym_space, args.sym_blank, args.report_bleu
    )
    # recognition error calculator
    self.error_calculator_asr = ASRErrorCalculator(
        args.char_list,
        args.sym_space,
        args.sym_blank,
        args.report_cer,
        args.report_wer,
    )
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
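# --- Illustrative sketch, not part of the original model code ---
# The attributes set above (mtlalpha, asr_weight, mt_weight) weight the
# auxiliary ASR/MT objectives against the main ST loss. A minimal standalone
# sketch of that kind of interpolation, assuming the convex-combination scheme
# the attribute names suggest (the real forward() may differ in detail):
import torch


def combine_losses(
    loss_st: torch.Tensor,
    loss_asr_att: torch.Tensor,
    loss_asr_ctc: torch.Tensor,
    loss_mt: torch.Tensor,
    asr_weight: float,
    mt_weight: float,
    mtlalpha: float,
) -> torch.Tensor:
    """Interpolate ST, ASR (CTC/attention), and MT losses into one scalar."""
    # mtlalpha balances CTC vs. attention inside the auxiliary ASR objective.
    loss_asr = mtlalpha * loss_asr_ctc + (1.0 - mtlalpha) * loss_asr_att
    st_weight = 1.0 - asr_weight - mt_weight
    return st_weight * loss_st + asr_weight * loss_asr + mt_weight * loss_mt


# Example: with both auxiliary weights at 0, the total reduces to loss_st.
total = combine_losses(
    torch.tensor(2.0), torch.tensor(3.0), torch.tensor(4.0), torch.tensor(5.0),
    asr_weight=0.0, mt_weight=0.0, mtlalpha=0.3,
)
assert torch.isclose(total, torch.tensor(2.0))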
def __init__(
    self,
    vocab_size: int,
    token_list: Union[Tuple[str, ...], List[str]],
    frontend: Optional[AbsFrontend],
    specaug: Optional[AbsSpecAug],
    normalize: Optional[AbsNormalize],
    preencoder: Optional[AbsPreEncoder],
    encoder: AbsEncoder,
    postencoder: Optional[AbsPostEncoder],
    decoder: AbsDecoder,
    extra_asr_decoder: Optional[AbsDecoder],
    extra_mt_decoder: Optional[AbsDecoder],
    ctc: CTC,
    src_vocab_size: int = 0,
    src_token_list: Union[Tuple[str, ...], List[str]] = [],
    asr_weight: float = 0.0,
    mt_weight: float = 0.0,
    mtlalpha: float = 0.0,
    ignore_id: int = -1,
    lsm_weight: float = 0.0,
    length_normalized_loss: bool = False,
    report_cer: bool = True,
    report_wer: bool = True,
    report_bleu: bool = True,
    sym_space: str = "<space>",
    sym_blank: str = "<blank>",
    extract_feats_in_collect_stats: bool = True,
):
    assert check_argument_types()
    assert 0.0 <= asr_weight < 1.0, "asr_weight should be [0.0, 1.0)"
    assert 0.0 <= mt_weight < 1.0, "mt_weight should be [0.0, 1.0)"
    assert 0.0 <= mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]"

    super().__init__()
    # note that eos is the same as sos (equivalent ID)
    self.sos = vocab_size - 1
    self.eos = vocab_size - 1
    self.vocab_size = vocab_size
    self.src_vocab_size = src_vocab_size
    self.ignore_id = ignore_id
    self.asr_weight = asr_weight
    self.mt_weight = mt_weight
    self.mtlalpha = mtlalpha
    self.token_list = token_list.copy()

    self.frontend = frontend
    self.specaug = specaug
    self.normalize = normalize
    self.preencoder = preencoder
    self.postencoder = postencoder
    self.encoder = encoder
    self.decoder = (
        decoder  # TODO(jiatong): directly implement multi-decoder structure at here
    )

    self.criterion_st = LabelSmoothingLoss(
        size=vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    self.criterion_asr = LabelSmoothingLoss(
        size=src_vocab_size,
        padding_idx=ignore_id,
        smoothing=lsm_weight,
        normalize_length=length_normalized_loss,
    )

    # submodule for ASR task
    if self.asr_weight > 0:
        assert (
            src_token_list is not None
        ), "Missing src_token_list, cannot add asr module to st model"
        if self.mtlalpha > 0.0:
            self.ctc = ctc
        if self.mtlalpha < 1.0:
            self.extra_asr_decoder = extra_asr_decoder
        elif extra_asr_decoder is not None:
            logging.warning(
                "Not using extra_asr_decoder because "
                "mtlalpha is set as {} (== 1.0)".format(mtlalpha),
            )

    # submodule for MT task
    if self.mt_weight > 0:
        self.extra_mt_decoder = extra_mt_decoder
    elif extra_mt_decoder is not None:
        logging.warning(
            "Not using extra_mt_decoder because "
            "mt_weight is set as {} (== 0)".format(mt_weight),
        )

    # MT error calculator
    if report_bleu:
        self.mt_error_calculator = MTErrorCalculator(
            token_list, sym_space, sym_blank, report_bleu
        )
    else:
        self.mt_error_calculator = None

    # ASR error calculator
    if report_cer or report_wer:
        assert (
            src_token_list is not None
        ), "Missing src_token_list, cannot add asr module to st model"
        self.asr_error_calculator = ASRErrorCalculator(
            src_token_list, sym_space, sym_blank, report_cer, report_wer
        )
    else:
        self.asr_error_calculator = None

    self.extract_feats_in_collect_stats = extract_feats_in_collect_stats
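# --- Illustrative sketch, not part of the original model code ---
# criterion_st / criterion_asr above are ESPnet's LabelSmoothingLoss with
# padded target positions masked via `ignore_id`. The same idea can be
# expressed with stock PyTorch (an analogy only, not ESPnet's exact
# implementation; the normalize_length option is omitted here):
import torch
import torch.nn as nn

vocab_size, ignore_id, lsm_weight = 50, -1, 0.1
criterion = nn.CrossEntropyLoss(ignore_index=ignore_id, label_smoothing=lsm_weight)

# (batch, seq_len) targets, with ignore_id marking padded positions.
logits = torch.randn(2, 4, vocab_size)  # (batch, seq, vocab)
targets = torch.tensor([[3, 7, 1, ignore_id],
                        [5, 2, ignore_id, ignore_id]])
# CrossEntropyLoss expects the class dimension second: (batch, vocab, seq).
loss = criterion(logits.transpose(1, 2), targets)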