@classmethod
def build(cls, idim: int, odim: int, **kwargs):
    """Initialize this class with python-level args.

    Args:
        idim (int): The number of input feature dimensions.
        odim (int): The size of the output vocabulary.

    Returns:
        ASRInterface: A new instance of ASRInterface.

    """
    # local import to avoid a cyclic import at module load time
    # (get_parser is assumed to come from espnet.bin.asr_train, mirroring the LM interface below)
    from espnet.bin.asr_train import get_parser

    def wrap(parser):
        return get_parser(parser, required=False)

    args = argparse.Namespace(**kwargs)
    args = fill_missing_args(args, wrap)
    args = fill_missing_args(args, cls.add_arguments)
    return cls(idim, odim, args)
@classmethod
def build(cls, n_vocab: int, **kwargs):
    """Initialize this class with python-level args.

    Args:
        n_vocab (int): The size of the vocabulary.

    Returns:
        LMInterface: A new instance of LMInterface.

    """
    # local import to avoid cyclic import in lm_train
    from espnet.bin.lm_train import get_parser

    def wrap(parser):
        return get_parser(parser, required=False)

    args = argparse.Namespace(**kwargs)
    args = fill_missing_args(args, wrap)
    args = fill_missing_args(args, cls.add_arguments)
    return cls(n_vocab, args)
@classmethod
def build(cls, target, **kwargs):
    """Initialize optimizer with python-level args.

    Args:
        target: for pytorch `model.parameters()`, for chainer `model`

    Returns:
        new Optimizer

    """
    args = argparse.Namespace(**kwargs)
    args = fill_missing_args(args, cls.add_arguments)
    return cls.from_args(target, args)
@classmethod
def build(cls, key: str, **kwargs):
    """Initialize this class with python-level args.

    Args:
        key (str): key of hyper parameter

    Returns:
        A new instance of this class.

    """
    def add(parser):
        return cls.add_arguments(key, parser)

    kwargs = {f"{key}_{cls.alias}_" + k: v for k, v in kwargs.items()}
    args = argparse.Namespace(**kwargs)
    args = fill_missing_args(args, add)
    return cls(key, args)
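# ---------------------------------------------------------------------------
# Usage sketch (illustrative, not part of the original sources): every build()
# factory above follows the same pattern -- keyword arguments are packed into an
# argparse.Namespace, and fill_missing_args() completes it with the defaults the
# class registers via add_arguments(). The model class, feature dimension, and
# vocabulary size below are made-up example values.
# ---------------------------------------------------------------------------
def _example_build_from_kwargs(asr_model_class, idim=83, odim=5002):
    """Build an ASR model with only a few python-level overrides (illustration only)."""
    # Options not given here fall back to the defaults declared by
    # asr_model_class.add_arguments().
    return asr_model_class.build(idim, odim, adim=256, elayers=6, dropout_rate=0.1)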
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    odim = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.output = torch.nn.Linear(256, self.odim)
    # mean + std pooling
    self.att = Attention(256)
    self.reset_parameters(args)
    self.adim = args.adim  # used for CTC (equal to d_model)
def __init__(
    self,
    idim: int,
    odim: int,
    args: Namespace,
    ignore_id: int = -1,
    blank_id: int = 0,
    training: bool = True,
):
    """Construct an E2E object for transducer model."""
    torch.nn.Module.__init__(self)

    args = fill_missing_args(args, self.add_arguments)

    self.is_transducer = True

    self.use_auxiliary_enc_outputs = (
        True if (training and args.use_aux_transducer_loss) else False
    )

    self.subsample = get_subsample(
        args, mode="asr", arch="transformer" if args.etype == "custom" else "rnn-t"
    )

    if self.use_auxiliary_enc_outputs:
        n_layers = (
            ((len(args.enc_block_arch) * args.enc_block_repeat) - 1)
            if args.enc_block_arch is not None
            else (args.elayers - 1)
        )

        aux_enc_output_layers = valid_aux_encoder_output_layers(
            args.aux_transducer_loss_enc_output_layers,
            n_layers,
            args.use_symm_kl_div_loss,
            self.subsample,
        )
    else:
        aux_enc_output_layers = []

    if args.etype == "custom":
        if args.enc_block_arch is None:
            raise ValueError(
                "When specifying custom encoder type, --enc-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.encoder = CustomEncoder(
            idim,
            args.enc_block_arch,
            input_layer=args.custom_enc_input_layer,
            repeat_block=args.enc_block_repeat,
            self_attn_type=args.custom_enc_self_attn_type,
            positional_encoding_type=args.custom_enc_positional_encoding_type,
            positionwise_activation_type=args.custom_enc_pw_activation_type,
            conv_mod_activation_type=args.custom_enc_conv_mod_activation_type,
            aux_enc_output_layers=aux_enc_output_layers,
        )
        encoder_out = self.encoder.enc_out
    else:
        self.enc = encoder_for(
            args,
            idim,
            self.subsample,
            aux_enc_output_layers=aux_enc_output_layers,
        )
        encoder_out = args.eprojs

    if args.dtype == "custom":
        if args.dec_block_arch is None:
            raise ValueError(
                "When specifying custom decoder type, --dec-block-arch "
                "should also be specified in training config. See "
                "egs/vivos/asr1/conf/transducer/train_*.yaml for more info."
            )

        self.decoder = CustomDecoder(
            odim,
            args.dec_block_arch,
            input_layer=args.custom_dec_input_layer,
            repeat_block=args.dec_block_repeat,
            positionwise_activation_type=args.custom_dec_pw_activation_type,
            dropout_rate_embed=args.dropout_rate_embed_decoder,
            blank_id=blank_id,
        )
        decoder_out = self.decoder.dunits
    else:
        self.dec = RNNDecoder(
            odim,
            args.dtype,
            args.dlayers,
            args.dunits,
            args.dec_embed_dim,
            dropout_rate=args.dropout_rate_decoder,
            dropout_rate_embed=args.dropout_rate_embed_decoder,
            blank_id=blank_id,
        )
        decoder_out = args.dunits

    self.transducer_tasks = TransducerTasks(
        encoder_out,
        decoder_out,
        args.joint_dim,
        odim,
        joint_activation_type=args.joint_activation_type,
        transducer_loss_weight=args.transducer_weight,
        ctc_loss=args.use_ctc_loss,
        ctc_loss_weight=args.ctc_loss_weight,
        ctc_loss_dropout_rate=args.ctc_loss_dropout_rate,
        lm_loss=args.use_lm_loss,
        lm_loss_weight=args.lm_loss_weight,
        lm_loss_smoothing_rate=args.lm_loss_smoothing_rate,
        aux_transducer_loss=args.use_aux_transducer_loss,
        aux_transducer_loss_weight=args.aux_transducer_loss_weight,
        aux_transducer_loss_mlp_dim=args.aux_transducer_loss_mlp_dim,
        aux_trans_loss_mlp_dropout_rate=args.aux_transducer_loss_mlp_dropout_rate,
        symm_kl_div_loss=args.use_symm_kl_div_loss,
        symm_kl_div_loss_weight=args.symm_kl_div_loss_weight,
        fastemit_lambda=args.fastemit_lambda,
        blank_id=blank_id,
        ignore_id=ignore_id,
        training=training,
    )

    if training and (args.report_cer or args.report_wer):
        self.error_calculator = ErrorCalculator(
            self.decoder if args.dtype == "custom" else self.dec,
            self.transducer_tasks.joint_network,
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None

    self.etype = args.etype
    self.dtype = args.dtype

    self.sos = odim - 1
    self.eos = odim - 1
    self.blank_id = blank_id
    self.ignore_id = ignore_id

    self.space = args.sym_space
    self.blank = args.sym_blank

    self.odim = odim

    self.reporter = Reporter()

    self.default_parameters(args)

    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer="embed",
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )
    self.pad = 0  # use <blank> for padding
    self.sos = odim - 1
    self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="mt", arch="transformer")
    self.reporter = Reporter()

    # tie source and target embeddings
    if args.tie_src_tgt_embedding:
        if idim != odim:
            raise ValueError(
                "When using tie_src_tgt_embedding, idim and odim must be equal."
            )
        self.encoder.embed[0].weight = self.decoder.embed[0].weight

    # tie embeddings and the classifier
    if args.tie_classifier:
        self.decoder.output_layer.weight = self.decoder.embed[0].weight

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.normalize_length = args.transformer_length_normalized_loss  # for PPL
    self.reset_parameters(args)
    self.adim = args.adim
    self.error_calculator = ErrorCalculator(
        args.char_list, args.sym_space, args.sym_blank, args.report_bleu
    )
    self.rnnlm = None

    # multilingual MT related
    self.multilingual = args.multilingual
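# ---------------------------------------------------------------------------
# Construction sketch (illustrative, not part of the original sources): each E2E
# constructor above starts with `args = fill_missing_args(args, self.add_arguments)`,
# so a Namespace carrying only a handful of overrides is enough -- every other
# hyperparameter falls back to the argparse defaults declared by add_arguments().
# `mt_model_class` stands in for the transformer MT model defined just above; the
# vocabulary sizes are made-up example values.
# ---------------------------------------------------------------------------
def _example_construct_from_partial_namespace(mt_model_class, idim=8000, odim=8000):
    """Instantiate an E2E model from a partial argparse.Namespace (illustration only)."""
    from argparse import Namespace

    partial_args = Namespace(adim=256, aheads=4, elayers=6, dlayers=6)
    # Missing options (dropout rates, label smoothing weight, ...) are filled in
    # inside __init__ by fill_missing_args().
    return mt_model_class(idim, odim, partial_args)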
def __init__(self, idim, odim, args=None):
    # initialize base classes
    TTSInterface.__init__(self)
    torch.nn.Module.__init__(self)

    # fill missing arguments
    args = fill_missing_args(args, self.add_arguments)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.reduction_factor = args.reduction_factor
    self.use_scaled_pos_enc = args.use_scaled_pos_enc
    self.use_masking = args.use_masking
    self.spk_embed_dim = args.spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = args.spk_embed_integration_type

    # TODO(kan-bayashi): support reduction_factor > 1
    if self.reduction_factor != 1:
        raise NotImplementedError("Support only reduction_factor = 1.")

    # use idx 0 as padding idx
    padding_idx = 0

    # get positional encoding class
    pos_enc_class = (
        ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
    )

    # define encoder
    encoder_input_layer = torch.nn.Embedding(
        num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx
    )
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=encoder_input_layer,
        dropout_rate=args.transformer_enc_dropout_rate,
        positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
        attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.encoder_normalize_before,
        concat_after=args.encoder_concat_after,
        positionwise_layer_type=args.positionwise_layer_type,
        positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
    )

    # define additional projection for speaker embedding
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
        else:
            self.projection = torch.nn.Linear(
                args.adim + self.spk_embed_dim, args.adim
            )

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=args.adim,
        n_layers=args.duration_predictor_layers,
        n_chans=args.duration_predictor_chans,
        kernel_size=args.duration_predictor_kernel_size,
        dropout_rate=args.duration_predictor_dropout_rate,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
    self.decoder = Encoder(
        idim=0,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        input_layer=None,
        dropout_rate=args.transformer_dec_dropout_rate,
        positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
        attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.decoder_normalize_before,
        concat_after=args.decoder_concat_after,
        positionwise_layer_type=args.positionwise_layer_type,
        positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
    )

    # define final projection
    self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)

    # initialize parameters
    self._reset_parameters(
        init_type=args.transformer_init,
        init_enc_alpha=args.initial_encoder_alpha,
        init_dec_alpha=args.initial_decoder_alpha,
    )

    # define teacher model
    if args.teacher_model is not None:
        self.teacher = self._load_teacher_model(args.teacher_model)
    else:
        self.teacher = None

    # define duration calculator
    if self.teacher is not None:
        self.duration_calculator = DurationCalculator(self.teacher)
    else:
        self.duration_calculator = None

    # transfer teacher parameters
    if self.teacher is not None and args.transfer_encoder_from_teacher:
        self._transfer_from_teacher(args.transferred_encoder_module)

    # define criterions
    self.duration_criterion = DurationPredictorLoss()
    # TODO(kan-bayashi): support knowledge distillation loss
    self.criterion = torch.nn.L1Loss()
def __init__(self, idim, odim, args): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ super(E2E, self).__init__() torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) self.asr_weight = args.asr_weight self.mt_weight = args.mt_weight self.mtlalpha = args.mtlalpha assert 0.0 <= self.asr_weight < 1.0, "asr_weight should be [0.0, 1.0)" assert 0.0 <= self.mt_weight < 1.0, "mt_weight should be [0.0, 1.0)" assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" self.etype = args.etype self.verbose = args.verbose # NOTE: for self.build method args.char_list = getattr(args, "char_list", None) self.char_list = args.char_list self.outdir = args.outdir self.space = args.sym_space self.blank = args.sym_blank self.reporter = Reporter() # below means the last number becomes eos/sos ID # note that sos/eos IDs are identical self.sos = odim - 1 self.eos = odim - 1 self.pad = 0 # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class # in ASR. However, blank labels are not used in MT. # To keep the vocabulary size, # we use index:0 for padding instead of adding one more class. # subsample info self.subsample = get_subsample(args, mode="st", arch="rnn") # label smoothing info if args.lsm_type and os.path.isfile(args.train_json): logging.info("Use label smoothing with " + args.lsm_type) labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json) else: labeldist = None # multilingual related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, "replace_sos", False) # encoder self.enc = encoder_for(args, idim, self.subsample) # attention (ST) self.att = att_for(args) # decoder (ST) self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) # submodule for ASR task self.ctc = None self.att_asr = None self.dec_asr = None if self.asr_weight > 0: if self.mtlalpha > 0.0: self.ctc = CTC( odim, args.eprojs, args.dropout_rate, ctc_type=args.ctc_type, reduce=True, ) if self.mtlalpha < 1.0: # attention (asr) self.att_asr = att_for(args) # decoder (asr) args_asr = copy.deepcopy(args) args_asr.atype = "location" # TODO(hirofumi0810): make this option self.dec_asr = decoder_for(args_asr, odim, self.sos, self.eos, self.att_asr, labeldist) # submodule for MT task if self.mt_weight > 0: self.embed_mt = torch.nn.Embedding(odim, args.eunits, padding_idx=self.pad) self.dropout_mt = torch.nn.Dropout(p=args.dropout_rate) self.enc_mt = encoder_for(args, args.eunits, subsample=np.ones(args.elayers + 1, dtype=np.int)) # weight initialization self.init_like_chainer() # options for beam search if self.asr_weight > 0 and args.report_cer or args.report_wer: recog_args = { "beam_size": args.beam_size, "penalty": args.penalty, "ctc_weight": args.ctc_weight, "maxlenratio": args.maxlenratio, "minlenratio": args.minlenratio, "lm_weight": args.lm_weight, "rnnlm": args.rnnlm, "nbest": args.nbest, "space": args.sym_space, "blank": args.sym_blank, "tgt_lang": False, } self.recog_args = argparse.Namespace(**recog_args) self.report_cer = args.report_cer self.report_wer = args.report_wer else: self.report_cer = False self.report_wer = False if args.report_bleu: trans_args = { "beam_size": args.beam_size, "penalty": args.penalty, "ctc_weight": 0, "maxlenratio": args.maxlenratio, "minlenratio": args.minlenratio, "lm_weight": args.lm_weight, "rnnlm": 
args.rnnlm, "nbest": args.nbest, "space": args.sym_space, "blank": args.sym_blank, "tgt_lang": False, } self.trans_args = argparse.Namespace(**trans_args) self.report_bleu = args.report_bleu else: self.report_bleu = False self.rnnlm = None self.logzero = -10000000000.0 self.loss = None self.acc = None
def __init__(self, idim, odim, args=None): """Initialize feed-forward Transformer module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - elayers (int): Number of encoder layers. - eunits (int): Number of encoder hidden units. - adim (int): Number of attention transformation dimensions. - aheads (int): Number of heads for multi head attention. - dlayers (int): Number of decoder layers. - dunits (int): Number of decoder hidden units. - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding. - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block. - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block. - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder. - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder. - duration_predictor_layers (int): Number of duration predictor layers. - duration_predictor_chans (int): Number of duration predictor channels. - duration_predictor_kernel_size (int): Kernel size of duration predictor. - spk_embed_dim (int): Number of speaker embedding dimenstions. - spk_embed_integration_type: How to integrate speaker embedding. - teacher_model (str): Teacher auto-regressive transformer model path. - reduction_factor (int): Reduction factor. - transformer_init (float): How to initialize transformer parameters. - transformer_lr (float): Initial value of learning rate. - transformer_warmup_steps (int): Optimizer warmup steps. - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding. - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding. - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module. - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding. - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding. - transformer_dec_attn_dropout_rate (float): Dropout rate in deocoder self-attention module. - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-deocoder attention module. - use_masking (bool): Whether to apply masking for padded part in loss calculation. - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. - transfer_encoder_from_teacher: Whether to transfer encoder using teacher encoder parameters. - transferred_encoder_module: Encoder module to be initialized using teacher parameters. 
""" # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.reduction_factor = args.reduction_factor self.use_scaled_pos_enc = args.use_scaled_pos_enc self.spk_embed_dim = args.spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = args.spk_embed_integration_type # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define encoder encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size) # define additional projection for speaker embedding if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) else: self.projection = torch.nn.Linear( args.adim + self.spk_embed_dim, args.adim) # define duration predictor self.duration_predictor = DurationPredictor( idim=args.adim, n_layers=args.duration_predictor_layers, n_chans=args.duration_predictor_chans, kernel_size=args.duration_predictor_kernel_size, dropout_rate=args.duration_predictor_dropout_rate, ) # define length regulator self.length_regulator = LengthRegulator() # define decoder # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder self.decoder = Encoder( idim=0, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=None, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. 
transformer_dec_positional_dropout_rate, attention_dropout_rate=args.transformer_dec_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # initialize parameters self._reset_parameters(init_type=args.transformer_init, init_enc_alpha=args.initial_encoder_alpha, init_dec_alpha=args.initial_decoder_alpha) # define teacher model if args.teacher_model is not None: self.teacher = self._load_teacher_model(args.teacher_model) else: self.teacher = None # define duration calculator if self.teacher is not None: self.duration_calculator = DurationCalculator(self.teacher) else: self.duration_calculator = None # transfer teacher parameters if self.teacher is not None and args.transfer_encoder_from_teacher: self._transfer_from_teacher(args.transferred_encoder_module) # define criterions self.criterion = FeedForwardTransformerLoss( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking)
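# ---------------------------------------------------------------------------
# Config-loading sketch (illustrative, not part of the original sources): in
# practice the constructors above are driven by a YAML training config (see the
# train_*.yaml files referenced elsewhere in this code). One minimal way to turn
# such a file into the Namespace these __init__ methods expect is shown below;
# the file name and the tts_model_class argument are assumptions.
# ---------------------------------------------------------------------------
def _example_build_tts_from_yaml(tts_model_class, idim=78, odim=80):
    """Build a TTS model from a YAML dict of hyperparameters (illustration only)."""
    import argparse

    import yaml  # PyYAML

    with open("conf/train_fastspeech.yaml") as f:
        conf = yaml.safe_load(f)  # e.g. {"adim": 384, "aheads": 4, ...}

    args = argparse.Namespace(**conf)
    # Any option the YAML omits is supplied by fill_missing_args() inside __init__.
    return tts_model_class(idim, odim, args)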
def __init__(self, idim, odim, args=None): """Initialize Tacotron2 module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - spk_embed_dim (int): Dimension of the speaker embedding. - elayers (int): The number of encoder blstm layers. - eunits (int): The number of encoder blstm units. - econv_layers (int): The number of encoder conv layers. - econv_filts (int): The number of encoder conv filter size. - econv_chans (int): The number of encoder conv filter channels. - dlayers (int): The number of decoder lstm layers. - dunits (int): The number of decoder lstm units. - prenet_layers (int): The number of prenet layers. - prenet_units (int): The number of prenet units. - postnet_layers (int): The number of postnet layers. - postnet_filts (int): The number of postnet filter size. - postnet_chans (int): The number of postnet filter channels. - output_activation (int): The name of activation function for outputs. - adim (int): The number of dimension of mlp in attention. - aconv_chans (int): The number of attention conv filter channels. - aconv_filts (int): The number of attention conv filter size. - cumulate_att_w (bool): Whether to cumulate previous attention weight. - use_batch_norm (bool): Whether to use batch normalization. - use_concate (int): Whether to concatenate encoder embedding with decoder lstm outputs. - dropout_rate (float): Dropout rate. - zoneout_rate (float): Zoneout rate. - reduction_factor (int): Reduction factor. - spk_embed_dim (int): Number of speaker embedding dimenstions. - spc_dim (int): Number of spectrogram embedding dimenstions (only for use_cbhg=True). - use_cbhg (bool): Whether to use CBHG module. - cbhg_conv_bank_layers (int): The number of convoluional banks in CBHG. - cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG. - cbhg_proj_filts (int): The number of filter size of projection layeri in CBHG. - cbhg_proj_chans (int): The number of channels of projection layer in CBHG. - cbhg_highway_layers (int): The number of layers of highway network in CBHG. - cbhg_highway_units (int): The number of units of highway network in CBHG. - cbhg_gru_units (int): The number of units of GRU in CBHG. - use_masking (bool): Whether to mask padded part in loss calculation. - bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True). - use-guided-attn-loss (bool): Whether to use guided attention loss. - guided-attn-loss-sigma (float) Sigma in guided attention loss. - guided-attn-loss-lamdba (float): Lambda in guided attention loss. 
""" # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.adim = args.adim self.spk_embed_dim = args.spk_embed_dim self.cumulate_att_w = args.cumulate_att_w self.reduction_factor = args.reduction_factor self.encoder_reduction_factor = args.encoder_reduction_factor self.use_cbhg = args.use_cbhg self.use_guided_attn_loss = args.use_guided_attn_loss self.src_reconstruction_loss_lambda = args.src_reconstruction_loss_lambda self.trg_reconstruction_loss_lambda = args.trg_reconstruction_loss_lambda # define activation function for the final output if args.output_activation is None: self.output_activation_fn = None elif hasattr(F, args.output_activation): self.output_activation_fn = getattr(F, args.output_activation) else: raise ValueError( "there is no such an activation function. (%s)" % args.output_activation ) # define network modules self.enc = Encoder( idim=idim * args.encoder_reduction_factor, input_layer="linear", elayers=args.elayers, eunits=args.eunits, econv_layers=args.econv_layers, econv_chans=args.econv_chans, econv_filts=args.econv_filts, use_batch_norm=args.use_batch_norm, use_residual=args.use_residual, dropout_rate=args.dropout_rate, ) dec_idim = ( args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim ) if args.atype == "location": att = AttLoc( dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts ) elif args.atype == "forward": att = AttForward( dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts ) if self.cumulate_att_w: logging.warning( "cumulation of attention weights is disabled in forward attention." ) self.cumulate_att_w = False elif args.atype == "forward_ta": att = AttForwardTA( dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts, odim, ) if self.cumulate_att_w: logging.warning( "cumulation of attention weights is disabled in forward attention." 
) self.cumulate_att_w = False else: raise NotImplementedError("Support only location or forward") self.dec = Decoder( idim=dec_idim, odim=odim, att=att, dlayers=args.dlayers, dunits=args.dunits, prenet_layers=args.prenet_layers, prenet_units=args.prenet_units, postnet_layers=args.postnet_layers, postnet_chans=args.postnet_chans, postnet_filts=args.postnet_filts, output_activation_fn=self.output_activation_fn, cumulate_att_w=self.cumulate_att_w, use_batch_norm=args.use_batch_norm, use_concate=args.use_concate, dropout_rate=args.dropout_rate, zoneout_rate=args.zoneout_rate, reduction_factor=args.reduction_factor, ) self.taco2_loss = Tacotron2Loss( use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight ) if self.use_guided_attn_loss: self.attn_loss = GuidedAttentionLoss( sigma=args.guided_attn_loss_sigma, alpha=args.guided_attn_loss_lambda, ) if self.use_cbhg: self.cbhg = CBHG( idim=odim, odim=args.spc_dim, conv_bank_layers=args.cbhg_conv_bank_layers, conv_bank_chans=args.cbhg_conv_bank_chans, conv_proj_filts=args.cbhg_conv_proj_filts, conv_proj_chans=args.cbhg_conv_proj_chans, highway_layers=args.cbhg_highway_layers, highway_units=args.cbhg_highway_units, gru_units=args.cbhg_gru_units, ) self.cbhg_loss = CBHGLoss(use_masking=args.use_masking) if self.src_reconstruction_loss_lambda > 0: self.src_reconstructor = Encoder( idim=dec_idim, input_layer="linear", elayers=args.elayers, eunits=args.eunits, econv_layers=args.econv_layers, econv_chans=args.econv_chans, econv_filts=args.econv_filts, use_batch_norm=args.use_batch_norm, use_residual=args.use_residual, dropout_rate=args.dropout_rate, ) self.src_reconstructor_linear = torch.nn.Linear( args.econv_chans, idim * args.encoder_reduction_factor ) self.src_reconstruction_loss = CBHGLoss(use_masking=args.use_masking) if self.trg_reconstruction_loss_lambda > 0: self.trg_reconstructor = Encoder( idim=dec_idim, input_layer="linear", elayers=args.elayers, eunits=args.eunits, econv_layers=args.econv_layers, econv_chans=args.econv_chans, econv_filts=args.econv_filts, use_batch_norm=args.use_batch_norm, use_residual=args.use_residual, dropout_rate=args.dropout_rate, ) self.trg_reconstructor_linear = torch.nn.Linear( args.econv_chans, odim * args.reduction_factor ) self.trg_reconstruction_loss = CBHGLoss(use_masking=args.use_masking) # load pretrained model if args.pretrained_model is not None: self.load_pretrained_model(args.pretrained_model)
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0): """Construct an E2E object for transducer model. Args: idim (int): dimension of inputs odim (int): dimension of outputs args (Namespace): argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.etype == "transformer": self.subsample = get_subsample(args, mode="asr", arch="transformer") self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate_encoder, ) else: self.subsample = get_subsample(args, mode="asr", arch="rnn-t") self.enc = encoder_for(args, idim, self.subsample) if args.dtype == "transformer": self.decoder = Decoder( odim=odim, jdim=args.joint_dim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=args.transformer_dec_input_layer, dropout_rate=args.dropout_rate_decoder, positional_dropout_rate=args.dropout_rate_decoder, attention_dropout_rate=args.transformer_attn_dropout_rate_decoder, ) else: if args.etype == "transformer": args.eprojs = args.adim if args.rnnt_mode == "rnnt-att": self.att = att_for(args) self.dec = decoder_for(args, odim, self.att) else: self.dec = decoder_for(args, odim) self.etype = args.etype self.dtype = args.dtype self.rnnt_mode = args.rnnt_mode self.sos = odim - 1 self.eos = odim - 1 self.blank_id = blank_id self.ignore_id = ignore_id self.space = args.sym_space self.blank = args.sym_blank self.odim = odim self.adim = args.adim self.reporter = Reporter() self.criterion = TransLoss(args.trans_type, self.blank_id) self.default_parameters(args) if args.report_cer or args.report_wer: from espnet.nets.e2e_asr_common import ErrorCalculatorTrans if self.dtype == "transformer": self.error_calculator = ErrorCalculatorTrans(self.decoder, args) else: self.error_calculator = ErrorCalculatorTrans(self.dec, args) else: self.error_calculator = None self.logzero = -10000000000.0 self.loss = None self.rnnlm = None
def __init__(self, idim, odim, args): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ super(E2E, self).__init__() torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) self.mtlalpha = args.mtlalpha assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" self.etype = args.etype self.verbose = args.verbose # NOTE: for self.build method args.char_list = getattr(args, "char_list", None) self.char_list = args.char_list self.outdir = args.outdir self.space = args.sym_space self.blank = args.sym_blank self.reporter = Reporter() # below means the last number becomes eos/sos ID # note that sos/eos IDs are identical self.sos = odim - 1 self.eos = odim - 1 # gs534 - word vocab bpe = len(self.char_list) > 100 # hack here for bpe flag self.vocabulary = Vocabulary(args.dictfile, bpe) if args.dictfile != '' else None # gs534 - create lexicon tree lextree = None self.meeting_KB = None self.n_KBs = getattr(args, 'dynamicKBs', 0) pretrain_emb = [] if args.meetingKB and args.meetingpath != '': if self.n_KBs == 0 or not os.path.isdir(os.path.join(args.meetingpath, 'split_0')): self.meeting_KB = KBmeeting(self.vocabulary, args.meetingpath, args.char_list, bpe) else: # arrange multiple KBs self.meeting_KB = [] for i in range(self.n_KBs): self.meeting_KB.append(KBmeeting(self.vocabulary, os.path.join(args.meetingpath, 'split_{}'.format(i)), args.char_list, bpe)) # subsample info self.subsample = get_subsample(args, mode="asr", arch="rnn") # label smoothing info if args.lsm_type and os.path.isfile(args.train_json): logging.info("Use label smoothing with " + args.lsm_type) labeldist = label_smoothing_dist( odim, args.lsm_type, transcript=args.train_json ) else: labeldist = None if getattr(args, "use_frontend", False): # use getattr to keep compatibility self.frontend = frontend_for(args, idim) self.feature_transform = feature_transform_for(args, (idim - 1) * 2) idim = args.n_mels else: self.frontend = None # encoder self.enc = encoder_for(args, idim, self.subsample) # ctc self.ctc = ctc_for(args, odim) # attention self.att = att_for(args) # decoder self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist, meetingKB=self.meeting_KB[0] if isinstance(self.meeting_KB, list) else self.meeting_KB) # weight initialization self.init_from = getattr(args, 'init_full_model', None) self.init_like_chainer() # options for beam search if args.report_cer or args.report_wer: recog_args = { "beam_size": args.beam_size, "penalty": args.penalty, "ctc_weight": args.ctc_weight, "maxlenratio": args.maxlenratio, "minlenratio": args.minlenratio, "lm_weight": args.lm_weight, "rnnlm": args.rnnlm, "nbest": args.nbest, "space": args.sym_space, "blank": args.sym_blank, } self.recog_args = argparse.Namespace(**recog_args) self.report_cer = args.report_cer self.report_wer = args.report_wer else: self.report_cer = False self.report_wer = False self.rnnlm = None self.logzero = -10000000000.0 self.loss = None self.acc = None
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0, training=True): """Construct an E2E object for transducer model.""" torch.nn.Module.__init__(self) args = fill_missing_args(args, self.add_arguments) self.is_rnnt = True self.transducer_weight = args.transducer_weight self.use_aux_task = (True if (args.aux_task_type is not None and training) else False) self.use_aux_ctc = args.aux_ctc and training self.aux_ctc_weight = args.aux_ctc_weight self.use_aux_cross_entropy = args.aux_cross_entropy and training self.aux_cross_entropy_weight = args.aux_cross_entropy_weight if self.use_aux_task: n_layers = ((len(args.enc_block_arch) * args.enc_block_repeat - 1) if args.enc_block_arch is not None else (args.elayers - 1)) aux_task_layer_list = valid_aux_task_layer_list( args.aux_task_layer_list, n_layers, ) else: aux_task_layer_list = [] if "custom" in args.etype: if args.enc_block_arch is None: raise ValueError( "When specifying custom encoder type, --enc-block-arch" "should also be specified in training config. See" "egs/vivos/asr1/conf/transducer/train_*.yaml for more info." ) self.subsample = get_subsample(args, mode="asr", arch="transformer") self.encoder = CustomEncoder( idim, args.enc_block_arch, input_layer=args.custom_enc_input_layer, repeat_block=args.enc_block_repeat, self_attn_type=args.custom_enc_self_attn_type, positional_encoding_type=args. custom_enc_positional_encoding_type, positionwise_activation_type=args. custom_enc_pw_activation_type, conv_mod_activation_type=args. custom_enc_conv_mod_activation_type, aux_task_layer_list=aux_task_layer_list, ) encoder_out = self.encoder.enc_out self.most_dom_list = args.enc_block_arch[:] else: self.subsample = get_subsample(args, mode="asr", arch="rnn-t") self.enc = encoder_for( args, idim, self.subsample, aux_task_layer_list=aux_task_layer_list, ) encoder_out = args.eprojs if "custom" in args.dtype: if args.dec_block_arch is None: raise ValueError( "When specifying custom decoder type, --dec-block-arch" "should also be specified in training config. See" "egs/vivos/asr1/conf/transducer/train_*.yaml for more info." ) self.decoder = CustomDecoder( odim, args.dec_block_arch, input_layer=args.custom_dec_input_layer, repeat_block=args.dec_block_repeat, positionwise_activation_type=args. 
custom_dec_pw_activation_type, dropout_rate_embed=args.dropout_rate_embed_decoder, ) decoder_out = self.decoder.dunits if "custom" in args.etype: self.most_dom_list += args.dec_block_arch[:] else: self.most_dom_list = args.dec_block_arch[:] else: self.dec = DecoderRNNT( odim, args.dtype, args.dlayers, args.dunits, blank_id, args.dec_embed_dim, args.dropout_rate_decoder, args.dropout_rate_embed_decoder, ) decoder_out = args.dunits self.joint_network = JointNetwork(odim, encoder_out, decoder_out, args.joint_dim, args.joint_activation_type) if hasattr(self, "most_dom_list"): self.most_dom_dim = sorted( Counter(d["d_hidden"] for d in self.most_dom_list if "d_hidden" in d).most_common(), key=lambda x: x[0], reverse=True, )[0][0] self.etype = args.etype self.dtype = args.dtype self.sos = odim - 1 self.eos = odim - 1 self.blank_id = blank_id self.ignore_id = ignore_id self.space = args.sym_space self.blank = args.sym_blank self.odim = odim self.reporter = Reporter() self.error_calculator = None self.default_parameters(args) if training: self.criterion = TransLoss(args.trans_type, self.blank_id) decoder = self.decoder if self.dtype == "custom" else self.dec if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( decoder, self.joint_network, args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) if self.use_aux_task: self.auxiliary_task = AuxiliaryTask( decoder, self.joint_network, self.criterion, args.aux_task_type, args.aux_task_weight, encoder_out, args.joint_dim, ) if self.use_aux_ctc: self.aux_ctc = ctc_for( Namespace( num_encs=1, eprojs=encoder_out, dropout_rate=args.aux_ctc_dropout_rate, ctc_type="warpctc", ), odim, ) if self.use_aux_cross_entropy: self.aux_decoder_output = torch.nn.Linear(decoder_out, odim) self.aux_cross_entropy = LabelSmoothingLoss( odim, ignore_id, args.aux_cross_entropy_smoothing) self.loss = None self.rnnlm = None
def __init__(self, idim, odim, args=None): # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.spk_embed_dim = args.spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = args.spk_embed_integration_type self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet(idim=idim, embed_dim=args.embed_dim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx), torch.nn.Linear(args.eprenet_conv_chans, args.adim)) else: encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) else: self.projection = torch.nn.Linear( args.adim + self.spk_embed_dim, args.adim) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet(idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate), torch.nn.Linear(args.dprenet_units, args.adim)) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. 
transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # define loss function self.criterion = TransformerLoss(use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( sigma=args.guided_attn_loss_sigma, alpha=args.guided_attn_loss_lambda, ) # initialize parameters self._reset_parameters(init_type=args.transformer_init, init_enc_alpha=args.initial_encoder_alpha, init_dec_alpha=args.initial_decoder_alpha)
def __init__(self, idim, odim, args): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ super(E2E, self).__init__() torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) self.etype = args.etype self.verbose = args.verbose # NOTE: for self.build method args.char_list = getattr(args, "char_list", None) self.char_list = args.char_list self.outdir = args.outdir self.space = args.sym_space self.blank = args.sym_blank self.reporter = Reporter() # below means the last number becomes eos/sos ID # note that sos/eos IDs are identical self.sos = odim - 1 self.eos = odim - 1 self.pad = 0 # NOTE: we reserve index:0 for <pad> although this is reserved for a blank class # in ASR. However, blank labels are not used in MT. # To keep the vocabulary size, # we use index:0 for padding instead of adding one more class. # subsample info self.subsample = get_subsample(args, mode="mt", arch="rnn") # label smoothing info if args.lsm_type and os.path.isfile(args.train_json): logging.info("Use label smoothing with " + args.lsm_type) labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json) else: labeldist = None # multilingual related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, "replace_sos", False) # encoder self.embed = torch.nn.Embedding(idim, args.eunits, padding_idx=self.pad) self.dropout = torch.nn.Dropout(p=args.dropout_rate) self.enc = encoder_for(args, args.eunits, self.subsample) # attention self.att = att_for(args) # decoder self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) # tie source and target emeddings if args.tie_src_tgt_embedding: if idim != odim: raise ValueError( "When using tie_src_tgt_embedding, idim and odim must be equal." ) if args.eunits != args.dunits: raise ValueError( "When using tie_src_tgt_embedding, eunits and dunits must be equal." ) self.embed.weight = self.dec.embed.weight # tie emeddings and the classfier if args.tie_classifier: if args.context_residual: raise ValueError( "When using tie_classifier, context_residual must be turned off." ) self.dec.output.weight = self.dec.embed.weight # weight initialization self.init_like_fairseq() # options for beam search if args.report_bleu: trans_args = { "beam_size": args.beam_size, "penalty": args.penalty, "ctc_weight": 0, "maxlenratio": args.maxlenratio, "minlenratio": args.minlenratio, "lm_weight": args.lm_weight, "rnnlm": args.rnnlm, "nbest": args.nbest, "space": args.sym_space, "blank": args.sym_blank, "tgt_lang": False, } self.trans_args = argparse.Namespace(**trans_args) self.report_bleu = args.report_bleu else: self.report_bleu = False self.rnnlm = None self.logzero = -10000000000.0 self.loss = None self.acc = None
def __init__(self, idim, odim, args=None): # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.spk_embed_dim = args.spk_embed_dim self.cumulate_att_w = args.cumulate_att_w self.reduction_factor = args.reduction_factor self.use_cbhg = args.use_cbhg self.use_guided_attn_loss = args.use_guided_attn_loss # define activation function for the final output if args.output_activation is None: self.output_activation_fn = None elif hasattr(F, args.output_activation): self.output_activation_fn = getattr(F, args.output_activation) else: raise ValueError('there is no such an activation function. (%s)' % args.output_activation) # set padding idx padding_idx = 0 # define network modules self.enc = Encoder(idim=idim, embed_dim=args.embed_dim, elayers=args.elayers, eunits=args.eunits, econv_layers=args.econv_layers, econv_chans=args.econv_chans, econv_filts=args.econv_filts, use_batch_norm=args.use_batch_norm, use_residual=args.use_residual, dropout_rate=args.dropout_rate, padding_idx=padding_idx) dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim if args.atype == "location": att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts) elif args.atype == "forward": att = AttForward(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts) if self.cumulate_att_w: logging.warning( "cumulation of attention weights is disabled in forward attention." ) self.cumulate_att_w = False elif args.atype == "forward_ta": att = AttForwardTA(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts, odim) if self.cumulate_att_w: logging.warning( "cumulation of attention weights is disabled in forward attention." ) self.cumulate_att_w = False else: raise NotImplementedError("Support only location or forward") self.dec = Decoder(idim=dec_idim, odim=odim, att=att, dlayers=args.dlayers, dunits=args.dunits, prenet_layers=args.prenet_layers, prenet_units=args.prenet_units, postnet_layers=args.postnet_layers, postnet_chans=args.postnet_chans, postnet_filts=args.postnet_filts, output_activation_fn=self.output_activation_fn, cumulate_att_w=self.cumulate_att_w, use_batch_norm=args.use_batch_norm, use_concate=args.use_concate, dropout_rate=args.dropout_rate, zoneout_rate=args.zoneout_rate, reduction_factor=args.reduction_factor) self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight) if self.use_guided_attn_loss: self.attn_loss = GuidedAttentionLoss( sigma=args.guided_attn_loss_sigma) if self.use_cbhg: self.cbhg = CBHG(idim=odim, odim=args.spc_dim, conv_bank_layers=args.cbhg_conv_bank_layers, conv_bank_chans=args.cbhg_conv_bank_chans, conv_proj_filts=args.cbhg_conv_proj_filts, conv_proj_chans=args.cbhg_conv_proj_chans, highway_layers=args.cbhg_highway_layers, highway_units=args.cbhg_highway_units, gru_units=args.cbhg_gru_units) self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args. transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, selfattention_layer_type=args. transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.pad = 0 # use <blank> for padding self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="st", arch="transformer") self.reporter = Reporter() self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) # submodule for ASR task self.mtlalpha = args.mtlalpha self.asr_weight = getattr(args, "asr_weight", 0.0) if self.asr_weight > 0 and args.mtlalpha < 1: self.decoder_asr = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) # submodule for MT task self.mt_weight = getattr(args, "mt_weight", 0.0) if self.mt_weight > 0: self.encoder_mt = Encoder( idim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer="embed", dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, padding_idx=0, ) self.reset_parameters( args) # NOTE: place after the submodule initialization self.adim = args.adim # used for CTC (equal to d_model) if self.asr_weight > 0 and args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None # translation error calculator self.error_calculator = MTErrorCalculator(args.char_list, args.sym_space, args.sym_blank, args.report_bleu) # recognition error calculator self.error_calculator_asr = ASRErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) self.rnnlm = None # multilingual E2E-ST related self.multilingual = getattr(args, "multilingual", False) self.replace_sos = getattr(args, "replace_sos", False)
def __init__(self, idim, odim, args): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ super(E2E, self).__init__() torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) self.mtlalpha = args.mtlalpha assert 0.0 <= self.mtlalpha <= 1.0, "mtlalpha should be [0.0, 1.0]" self.etype = args.etype self.verbose = args.verbose # NOTE: for self.build method args.char_list = getattr(args, "char_list", None) self.char_list = args.char_list self.outdir = args.outdir self.space = args.sym_space self.blank = args.sym_blank self.reporter = Reporter() # below means the last number becomes eos/sos ID # note that sos/eos IDs are identical self.sos = odim - 1 self.eos = odim - 1 # subsample info self.subsample = get_subsample(args, mode="asr", arch="rnn") # label smoothing info if args.lsm_type and os.path.isfile(args.train_json): logging.info("Use label smoothing with " + args.lsm_type) labeldist = label_smoothing_dist(odim, args.lsm_type, transcript=args.train_json) else: labeldist = None if getattr(args, "use_frontend", False): # use getattr to keep compatibility self.frontend = frontend_for(args, idim) self.feature_transform = feature_transform_for( args, (idim - 1) * 2) idim = args.n_mels else: self.frontend = None # encoder self.enc = encoder_for(args, idim, self.subsample) # ctc self.ctc = ctc_for(args, odim) # attention self.att = att_for(args) # decoder self.dec = decoder_for(args, odim, self.sos, self.eos, self.att, labeldist) # weight initialization self.init_like_chainer() # options for beam search if args.report_cer or args.report_wer: recog_args = { "beam_size": args.beam_size, "penalty": args.penalty, "ctc_weight": args.ctc_weight, "maxlenratio": args.maxlenratio, "minlenratio": args.minlenratio, "lm_weight": args.lm_weight, "rnnlm": args.rnnlm, "nbest": args.nbest, "space": args.sym_space, "blank": args.sym_blank, } self.recog_args = argparse.Namespace(**recog_args) self.report_cer = args.report_cer self.report_wer = args.report_wer else: self.report_cer = False self.report_wer = False self.rnnlm = None self.logzero = -10000000000.0 self.loss = None self.acc = None
def __init__(self, idim, odim, args=None): """Initialize TTS-Transformer module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - embed_dim (int): Dimension of character embedding. - eprenet_conv_layers (int): Number of encoder prenet convolution layers. - eprenet_conv_chans (int): Number of encoder prenet convolution channels. - eprenet_conv_filts (int): Filter size of encoder prenet convolution. - dprenet_layers (int): Number of decoder prenet layers. - dprenet_units (int): Number of decoder prenet hidden units. - elayers (int): Number of encoder layers. - eunits (int): Number of encoder hidden units. - adim (int): Number of attention transformation dimensions. - aheads (int): Number of heads for multi head attention. - dlayers (int): Number of decoder layers. - dunits (int): Number of decoder hidden units. - postnet_layers (int): Number of postnet layers. - postnet_chans (int): Number of postnet channels. - postnet_filts (int): Filter size of postnet. - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding. - use_batch_norm (bool): Whether to use batch normalization in encoder prenet. - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block. - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block. - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder. - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder. - reduction_factor (int): Reduction factor. - spk_embed_dim (int): Number of speaker embedding dimenstions. - spk_embed_integration_type: How to integrate speaker embedding. - transformer_init (float): How to initialize transformer parameters. - transformer_lr (float): Initial value of learning rate. - transformer_warmup_steps (int): Optimizer warmup steps. - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding. - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding. - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module. - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding. - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding. - transformer_dec_attn_dropout_rate (float): Dropout rate in deocoder self-attention module. - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-deocoder attention module. - eprenet_dropout_rate (float): Dropout rate in encoder prenet. - dprenet_dropout_rate (float): Dropout rate in decoder prenet. - postnet_dropout_rate (float): Dropout rate in postnet. - use_masking (bool): Whether to apply masking for padded part in loss calculation. - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. - bce_pos_weight (float): Positive sample weight in bce calculation (only for use_masking=true). - loss_type (str): How to calculate loss. - use_guided_attn_loss (bool): Whether to use guided attention loss. - num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss. - num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss. - modules_applied_guided_attn (list): List of module names to apply guided attention loss. - guided-attn-loss-sigma (float) Sigma in guided attention loss. 
- guided-attn-loss-lambda (float): Lambda in guided attention loss. """ # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.spk_embed_dim = args.spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = args.spk_embed_integration_type self.use_scaled_pos_enc = args.use_scaled_pos_enc self.reduction_factor = args.reduction_factor self.loss_type = args.loss_type self.use_guided_attn_loss = args.use_guided_attn_loss if self.use_guided_attn_loss: if args.num_layers_applied_guided_attn == -1: self.num_layers_applied_guided_attn = args.elayers else: self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn if args.num_heads_applied_guided_attn == -1: self.num_heads_applied_guided_attn = args.aheads else: self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn self.modules_applied_guided_attn = args.modules_applied_guided_attn # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define transformer encoder if args.eprenet_conv_layers != 0: # encoder prenet encoder_input_layer = torch.nn.Sequential( EncoderPrenet(idim=idim, embed_dim=args.embed_dim, elayers=0, econv_layers=args.eprenet_conv_layers, econv_chans=args.eprenet_conv_chans, econv_filts=args.eprenet_conv_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.eprenet_dropout_rate, padding_idx=padding_idx), torch.nn.Linear(args.eprenet_conv_chans, args.adim)) else: encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size, ) # define projection layer if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) else: self.projection = torch.nn.Linear( args.adim + self.spk_embed_dim, args.adim) # define transformer decoder if args.dprenet_layers != 0: # decoder prenet decoder_input_layer = torch.nn.Sequential( DecoderPrenet(idim=odim, n_layers=args.dprenet_layers, n_units=args.dprenet_units, dropout_rate=args.dprenet_dropout_rate), torch.nn.Linear(args.dprenet_units, args.adim)) else: decoder_input_layer = "linear" self.decoder = Decoder( odim=-1, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. transformer_dec_positional_dropout_rate, self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate, src_attention_dropout_rate=args. 
transformer_enc_dec_attn_dropout_rate, input_layer=decoder_input_layer, use_output_layer=False, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # define loss function self.criterion = TransformerLoss( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking, bce_pos_weight=args.bce_pos_weight) if self.use_guided_attn_loss: self.attn_criterion = GuidedMultiHeadAttentionLoss( sigma=args.guided_attn_loss_sigma, alpha=args.guided_attn_loss_lambda, ) # initialize parameters self._reset_parameters(init_type=args.transformer_init, init_enc_alpha=args.initial_encoder_alpha, init_dec_alpha=args.initial_decoder_alpha) # load pretrained model if args.pretrained_model is not None: self.load_pretrained_model(args.pretrained_model)
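Because feat_out maps the attention dimension to odim * reduction_factor, each decoder position in this module emits reduction_factor acoustic frames stacked along the feature axis. The snippet below is a small illustrative sketch of the unfolding this implies; the shapes and variable names are assumptions, not the module's actual forward code.

import torch

batch, dec_len, odim, r = 2, 5, 80, 3
flat = torch.randn(batch, dec_len, odim * r)   # what a feat_out projection produces
frames = flat.view(batch, dec_len * r, odim)   # one row per acoustic frame
assert frames.shape == (2, 15, 80)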
def __init__(self, idim, odim, args=None, com_args=None): """Initialize Tacotron2 module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - spk_embed_dim (int): Dimension of the speaker embedding. - embed_dim (int): Dimension of character embedding. - elayers (int): The number of encoder blstm layers. - eunits (int): The number of encoder blstm units. - econv_layers (int): The number of encoder conv layers. - econv_filts (int): The filter size of encoder conv layers. - econv_chans (int): The number of encoder conv filter channels. - dlayers (int): The number of decoder lstm layers. - dunits (int): The number of decoder lstm units. - prenet_layers (int): The number of prenet layers. - prenet_units (int): The number of prenet units. - postnet_layers (int): The number of postnet layers. - postnet_filts (int): The filter size of postnet layers. - postnet_chans (int): The number of postnet filter channels. - output_activation (str): The name of activation function for outputs. - use_batch_norm (bool): Whether to use batch normalization. - use_concate (bool): Whether to concatenate encoder embedding with decoder lstm outputs. - dropout_rate (float): Dropout rate. - zoneout_rate (float): Zoneout rate. - reduction_factor (int): Reduction factor. - spc_dim (int): Number of spectrogram embedding dimensions (only for use_cbhg=True). - use_masking (bool): Whether to apply masking for padded part in loss calculation. - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. - duration_predictor_layers (int): Number of duration predictor layers. - duration_predictor_chans (int): Number of duration predictor channels. - duration_predictor_kernel_size (int): Kernel size of duration predictor. """ # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) args = vars(args) if 'use_fe_condition' not in args.keys(): args['use_fe_condition'] = com_args.use_fe_condition if 'append_position' not in args.keys(): args['append_position'] = com_args.append_position args = argparse.Namespace(**args) # store hyperparameters self.idim = idim self.odim = odim self.embed_dim = args.embed_dim self.spk_embed_dim = args.spk_embed_dim self.reduction_factor = args.reduction_factor self.use_fe_condition = args.use_fe_condition self.append_position = args.append_position # define activation function for the final output if args.output_activation is None: self.output_activation_fn = None elif hasattr(F, args.output_activation): self.output_activation_fn = getattr(F, args.output_activation) else: raise ValueError("there is no such activation function.
(%s)" % args.output_activation) # set padding idx padding_idx = 0 # define network modules self.enc = Encoder( idim=idim, embed_dim=args.embed_dim, elayers=args.elayers, eunits=args.eunits, econv_layers=args.econv_layers, econv_chans=args.econv_chans, econv_filts=args.econv_filts, use_batch_norm=args.use_batch_norm, use_residual=args.use_residual, dropout_rate=args.dropout_rate, padding_idx=padding_idx, resume=args.encoder_resume, ) dec_idim = (args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim) self.dec = Decoder( idim=dec_idim, odim=odim, dlayers=args.dlayers, dunits=args.dunits, prenet_layers=args.prenet_layers, prenet_units=args.prenet_units, postnet_layers=args.postnet_layers, postnet_chans=args.postnet_chans, postnet_filts=args.postnet_filts, output_activation_fn=self.output_activation_fn, use_batch_norm=args.use_batch_norm, use_concate=args.use_concate, dropout_rate=args.dropout_rate, zoneout_rate=args.zoneout_rate, reduction_factor=args.reduction_factor, use_fe_condition=args.use_fe_condition, append_position=args.append_position, ) self.duration_predictor = DurationPredictor( idim=dec_idim, n_layers=args.duration_predictor_layers, n_chans=args.duration_predictor_chans, kernel_size=args.duration_predictor_kernel_size, dropout_rate=args.duration_predictor_dropout_rate, ) reduction = 'none' if args.use_weighted_masking else 'mean' self.duration_criterion = DurationPredictorLoss(reduction=reduction) #-------------- picth/energy predictor definition ---------------# if self.use_fe_condition: output_dim = 1 # pitch prediction pitch_predictor_layers = 2 pitch_predictor_chans = 384 pitch_predictor_kernel_size = 3 pitch_predictor_dropout_rate = 0.5 pitch_embed_kernel_size = 9 pitch_embed_dropout_rate = 0.5 self.stop_gradient_from_pitch_predictor = False self.pitch_predictor = VariancePredictor( idim=dec_idim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans, kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout_rate, output_dim=output_dim, ) self.pitch_embed = torch.nn.Sequential( torch.nn.Conv1d( in_channels=1, out_channels=dec_idim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2, ), torch.nn.Dropout(pitch_embed_dropout_rate), ) # energy prediction energy_predictor_layers = 2 energy_predictor_chans = 384 energy_predictor_kernel_size = 3 energy_predictor_dropout_rate = 0.5 energy_embed_kernel_size = 9 energy_embed_dropout_rate = 0.5 self.stop_gradient_from_energy_predictor = False self.energy_predictor = VariancePredictor( idim=dec_idim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans, kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout_rate, output_dim=output_dim, ) self.energy_embed = torch.nn.Sequential( torch.nn.Conv1d( in_channels=1, out_channels=dec_idim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2, ), torch.nn.Dropout(energy_embed_dropout_rate), ) # define criterions self.prosody_criterion = prosody_criterions( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking) self.taco2_loss = Tacotron2Loss( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking, ) # load pretrained model if args.pretrained_model is not None: self.load_pretrained_model(args.pretrained_model) print('\n############## number of network parameters ##############\n') parameters = filter(lambda p: p.requires_grad, self.enc.parameters()) parameters = sum([np.prod(p.size()) for p in 
parameters]) / 1_000_000 print('Trainable Parameters for Encoder: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.dec.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for Decoder: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.duration_predictor.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for duration_predictor: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.pitch_predictor.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for pitch_predictor: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.energy_predictor.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for energy_predictor: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.pitch_embed.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for pitch_embed: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.energy_embed.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for energy_embed: %.5fM' % parameters) parameters = filter(lambda p: p.requires_grad, self.parameters()) parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000 print('Trainable Parameters for whole network: %.5fM' % parameters) print('\n##########################################################\n')
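The block of per-module parameter counts above repeats the same filter/sum/print pattern for every sub-module. A possible loop-based sketch that computes the same numbers more compactly (the helper name count_params_m is hypothetical, not part of the codebase):

import numpy as np

def count_params_m(module):
    """Trainable parameters of `module`, in millions."""
    return sum(np.prod(p.size()) for p in module.parameters() if p.requires_grad) / 1_000_000

# e.g. inside __init__, after all sub-modules are defined:
# for name, m in [("Encoder", self.enc), ("Decoder", self.dec),
#                 ("duration_predictor", self.duration_predictor)]:
#     print("Trainable Parameters for %s: %.5fM" % (name, count_params_m(m)))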
def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) # fill missing arguments for compatibility args = fill_missing_args(args, self.add_arguments) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, selfattention_layer_type=args.transformer_encoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_encoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) if args.mtlalpha < 1: self.decoder = Decoder( odim=odim, selfattention_layer_type=args.transformer_decoder_selfattn_layer_type, attention_dim=args.adim, attention_heads=args.aheads, conv_wshare=args.wshare, conv_kernel_length=args.ldconv_decoder_kernel_length, conv_usebias=args.ldconv_usebias, linear_units=args.dunits, num_blocks=args.dlayers, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.criterion = LabelSmoothingLoss( odim, ignore_id, args.lsm_weight, args.transformer_length_normalized_loss, ) else: self.decoder = None self.criterion = None self.blank = 0 self.decoder_mode = args.decoder_mode if self.decoder_mode == "maskctc": self.mask_token = odim - 1 self.sos = odim - 2 self.eos = odim - 2 else: self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = get_subsample(args, mode="asr", arch="transformer") self.reporter = Reporter() self.reset_parameters(args) self.adim = args.adim # used for CTC (equal to d_model) self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC( odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True ) else: self.ctc = None if args.report_cer or args.report_wer: self.error_calculator = ErrorCalculator( args.char_list, args.sym_space, args.sym_blank, args.report_cer, args.report_wer, ) else: self.error_calculator = None self.rnnlm = None
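mtlalpha controls how far this model leans toward CTC: at 1.0 the decoder and attention criterion are dropped entirely, at 0.0 the CTC branch is disabled, and in between both branches are built and trained. The weighting sketched below is the usual hybrid CTC/attention combination and is stated here as an assumption; the constructor itself only instantiates the branches.

def combine_losses(loss_ctc, loss_att, mtlalpha):
    """Sketch of the hybrid weighting implied by mtlalpha (assumed, not quoted)."""
    if mtlalpha == 1.0:   # pure CTC: no decoder, no attention loss
        return loss_ctc
    if mtlalpha == 0.0:   # pure attention: CTC branch is disabled
        return loss_att
    return mtlalpha * loss_ctc + (1.0 - mtlalpha) * loss_att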
def __init__(self, idim, odim, args=None): """Initialize Tacotron2 module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - embed_dim (int): Dimension of character embedding. - elayers (int): The number of encoder blstm layers. - eunits (int): The number of encoder blstm units. - econv_layers (int): The number of encoder conv layers. - econv_filts (int): The number of encoder conv filter size. - econv_chans (int): The number of encoder conv filter channels. - dlayers (int): The number of decoder lstm layers. - dunits (int): The number of decoder lstm units. - prenet_layers (int): The number of prenet layers. - prenet_units (int): The number of prenet units. - postnet_layers (int): The number of postnet layers. - postnet_filts (int): The number of postnet filter size. - postnet_chans (int): The number of postnet filter channels. - output_activation (int): The name of activation function for outputs. - adim (int): The number of dimension of mlp in attention. - aconv_chans (int): The number of attention conv filter channels. - aconv_filts (int): The number of attention conv filter size. - cumulate_att_w (bool): Whether to cumulate previous attention weight. - use_batch_norm (bool): Whether to use batch normalization. - use_concate (int): Whether to concatenate encoder embedding with decoder lstm outputs. - dropout_rate (float): Dropout rate. - zoneout_rate (float): Zoneout rate. - reduction_factor (int): Reduction factor. - spk_embed_dim (int): Number of speaker embedding dimension. - spkidloss_weight (float): Weight for the speaker id module when combined to tts loss - num_spk (int): Number of speakers in the training data - spc_dim (int): Number of spectrogram embedding dimenstions (only for use_cbhg=True). - use_cbhg (bool): Whether to use CBHG module. - cbhg_conv_bank_layers (int): The number of convoluional banks in CBHG. - cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG. - cbhg_proj_filts (int): The number of filter size of projection layeri in CBHG. - cbhg_proj_chans (int): The number of channels of projection layer in CBHG. - cbhg_highway_layers (int): The number of layers of highway network in CBHG. - cbhg_highway_units (int): The number of units of highway network in CBHG. - cbhg_gru_units (int): The number of units of GRU in CBHG. - use_masking (bool): Whether to mask padded part in loss calculation. - bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True). - use-guided-attn-loss (bool): Whether to use guided attention loss. - guided-attn-loss-sigma (float) Sigma in guided attention loss. - guided-attn-loss-lamdba (float): Lambda in guided attention loss. """ # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.spk_embed_dim = args.spk_embed_dim self.spkidloss_weight = args.spkidloss_weight self.cumulate_att_w = args.cumulate_att_w self.reduction_factor = args.reduction_factor self.use_cbhg = args.use_cbhg self.use_guided_attn_loss = args.use_guided_attn_loss # define activation function for the final output if args.output_activation is None: self.output_activation_fn = None elif hasattr(F, args.output_activation): self.output_activation_fn = getattr(F, args.output_activation) else: raise ValueError('there is no such an activation function. 
(%s)' % args.output_activation) # set padding idx padding_idx = 0 # define network modules self.enc = Encoder(idim=idim, embed_dim=args.embed_dim, elayers=args.elayers, eunits=args.eunits, econv_layers=args.econv_layers, econv_chans=args.econv_chans, econv_filts=args.econv_filts, use_batch_norm=args.use_batch_norm, use_residual=args.use_residual, dropout_rate=args.dropout_rate, padding_idx=padding_idx) if args.train_spkid_extractor: self.train_spkid_extractor = True self.resnet_spkid = E2E_speakerid(input_dim=odim, output_dim=args.num_spk, Q=odim - 1, D=32, hidden_dim=args.spk_embed_dim, pooling='mean', network_type='lde', distance_type='sqr', asoftmax=True, resnet_AvgPool2d_fre_ksize=10) self.angle_loss = AngleLoss() else: self.train_spkid_extractor = False dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim if args.atype == "location": att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts) elif args.atype == "forward": att = AttForward(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts) if self.cumulate_att_w: logging.warning( "cumulation of attention weights is disabled in forward attention." ) self.cumulate_att_w = False elif args.atype == "forward_ta": att = AttForwardTA(dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts, odim) if self.cumulate_att_w: logging.warning( "cumulation of attention weights is disabled in forward attention." ) self.cumulate_att_w = False elif args.atype == "noatt": # This condition is satisfied only when using phone alignment for TTS input (Currently when using phn. ali. in TTS training for learning speaker embedding) att = None else: raise NotImplementedError("Support only location or forward") self.dec = Decoder(idim=dec_idim, odim=odim, att=att, dlayers=args.dlayers, dunits=args.dunits, prenet_layers=args.prenet_layers, prenet_units=args.prenet_units, postnet_layers=args.postnet_layers, postnet_chans=args.postnet_chans, postnet_filts=args.postnet_filts, output_activation_fn=self.output_activation_fn, cumulate_att_w=self.cumulate_att_w, use_batch_norm=args.use_batch_norm, use_concate=args.use_concate, dropout_rate=args.dropout_rate, zoneout_rate=args.zoneout_rate, reduction_factor=args.reduction_factor) self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight) if self.use_guided_attn_loss: self.attn_loss = GuidedAttentionLoss( sigma=args.guided_attn_loss_sigma, alpha=args.guided_attn_loss_lambda, ) if self.use_cbhg: self.cbhg = CBHG(idim=odim, odim=args.spc_dim, conv_bank_layers=args.cbhg_conv_bank_layers, conv_bank_chans=args.cbhg_conv_bank_chans, conv_proj_filts=args.cbhg_conv_proj_filts, conv_proj_chans=args.cbhg_conv_proj_chans, highway_layers=args.cbhg_highway_layers, highway_units=args.cbhg_highway_units, gru_units=args.cbhg_gru_units) self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
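When use_guided_attn_loss is enabled, GuidedAttentionLoss(sigma, alpha) penalises attention weights that stray from the input-output diagonal. The snippet below is a rough sketch of the soft-diagonal weight matrix behind that idea, written for illustration only; it is not the class's actual implementation, and the function name is hypothetical.

import torch

def guided_attention_mask(ilen, olen, sigma):
    """W[t, n] grows as the attention drifts away from the diagonal."""
    grid_o, grid_i = torch.meshgrid(
        torch.arange(olen).float(), torch.arange(ilen).float(), indexing="ij")
    return 1.0 - torch.exp(-((grid_i / ilen - grid_o / olen) ** 2) / (2 * sigma ** 2))

mask = guided_attention_mask(ilen=30, olen=120, sigma=0.4)  # shape (olen, ilen)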