class TransformerLM(nn.Module, LMInterface, BatchScorerInterface):
    """Transformer language model.

    Token ids are embedded, run through a self-attention ``Encoder`` under a
    causal mask and projected to vocabulary logits by a linear layer.
    Token id ``0`` is treated as padding when masks are built.
    """

    @staticmethod
    def add_arguments(parser):
        """Add arguments to command line argument parser."""
        # (flag, default, help) for every integer-valued option, in
        # registration order
        int_options = (
            ('--layer', 4, 'Number of hidden layers'),
            ('--unit', 1024, 'Number of hidden units in feedforward layer'),
            ('--att-unit', 256, 'Number of hidden units in attention layer'),
            ('--embed-unit', 128, 'Number of hidden units in embedding layer'),
            ('--head', 2, 'Number of multi head attention'),
        )
        for flag, default, description in int_options:
            parser.add_argument(flag, type=int, default=default,
                                help=description)
        parser.add_argument('--dropout-rate', type=float, default=0.5,
                            help='dropout probability')
        parser.add_argument('--pos-enc', default="sinusoidal",
                            choices=["sinusoidal", "none"],
                            help='positional encoding')
        return parser

    def __init__(self, n_vocab, args):
        """Initialize class.

        Args:
            n_vocab (int): The size of the vocabulary
            args (argparse.Namespace): configurations.
                see :py:meth:`add_arguments`

        Raises:
            ValueError: If ``args.pos_enc`` is neither ``"sinusoidal"`` nor
                ``"none"``.

        """
        nn.Module.__init__(self)

        pos_enc = args.pos_enc
        if pos_enc not in ("sinusoidal", "none"):
            raise ValueError(f"unknown pos-enc option: {pos_enc}")
        if pos_enc == "sinusoidal":
            pos_enc_class = PositionalEncoding
        else:
            def pos_enc_class(*_unused_args, **_unused_kwargs):
                # an empty Sequential behaves as the identity
                return nn.Sequential()

        self.embed = nn.Embedding(n_vocab, args.embed_unit)
        self.encoder = Encoder(
            idim=args.embed_unit,
            attention_dim=args.att_unit,
            attention_heads=args.head,
            linear_units=args.unit,
            num_blocks=args.layer,
            dropout_rate=args.dropout_rate,
            input_layer="linear",
            pos_enc_class=pos_enc_class,
        )
        self.decoder = nn.Linear(args.att_unit, n_vocab)

    def _target_mask(self, ys_in_pad):
        """Combine the non-padding mask of ``ys_in_pad`` with a causal mask."""
        non_pad = ys_in_pad != 0
        causal = subsequent_mask(non_pad.size(-1), device=non_pad.device)
        return non_pad.unsqueeze(-2) & causal.unsqueeze(0)

    def forward(
        self, x: torch.Tensor, t: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Compute LM loss value from buffer sequences.

        Args:
            x (torch.Tensor): Input ids. (batch, len)
            t (torch.Tensor): Target ids. (batch, len)

        Returns:
            tuple[torch.Tensor, torch.Tensor, torch.Tensor]: Tuple of
                loss to backward (scalar),
                negative log-likelihood of t: -log p(t) (scalar) and
                the number of non-padding elements in x (scalar)

        Notes:
            The last two return values are used in perplexity:
            p(t)^{-n} = exp(-log p(t) / n)

        """
        hidden, _ = self.encoder(self.embed(x), self._target_mask(x))
        logits = self.decoder(hidden)
        token_nll = F.cross_entropy(
            logits.view(-1, logits.shape[-1]), t.view(-1), reduction="none"
        )
        # positions whose input id is 0 do not contribute to the loss
        valid = (x != 0).to(dtype=token_nll.dtype)
        nll = (token_nll * valid.view(-1)).sum()
        count = valid.sum()
        return nll / count, nll, count

    def score(self, y: torch.Tensor, state: Any,
              x: torch.Tensor) -> Tuple[torch.Tensor, Any]:
        """Score new token.

        Args:
            y (torch.Tensor): 1D torch.int64 prefix tokens.
            state: Scorer state for prefix tokens
            x (torch.Tensor): encoder feature that generates ys
                (not used by this scorer).

        Returns:
            tuple[torch.Tensor, Any]: Tuple of
                torch.float32 scores for next token (n_vocab)
                and next state for ys

        """
        prefix = y.unsqueeze(0)
        hidden, _, cache = self.encoder.forward_one_step(
            self.embed(prefix), self._target_mask(prefix), cache=state
        )
        logits = self.decoder(hidden[:, -1])
        return logits.log_softmax(dim=-1).squeeze(0), cache

    # batch beam search API (see BatchScorerInterface)
    def batch_score(self, ys: torch.Tensor, states: List[Any],
                    xs: torch.Tensor) -> Tuple[torch.Tensor, List[Any]]:
        """Score new token batch (required).

        Args:
            ys (torch.Tensor): torch.int64 prefix tokens (n_batch, ylen).
            states (List[Any]): Scorer states for prefix tokens.
            xs (torch.Tensor): The encoder feature that generates ys
                (n_batch, xlen, n_feat); not used by this scorer.

        Returns:
            tuple[torch.Tensor, List[Any]]: Tuple of
                batchified scores for next token with shape of
                `(n_batch, n_vocab)` and next state list for ys.

        """
        n_batch = len(ys)
        n_layers = len(self.encoder.encoders)

        # merge states: [batch][layer] -> [layer] of batch-stacked tensors
        batch_state = None
        if states[0] is not None:
            batch_state = [
                torch.stack([states[b][layer] for b in range(n_batch)])
                for layer in range(n_layers)
            ]

        # batch decoding
        hidden, _, new_cache = self.encoder.forward_one_step(
            self.embed(ys), self._target_mask(ys), cache=batch_state
        )
        logp = self.decoder(hidden[:, -1]).log_softmax(dim=-1)

        # split states back: [layer][batch] -> [batch][layer]
        state_list = [
            [new_cache[layer][b] for layer in range(n_layers)]
            for b in range(n_batch)
        ]
        return logp, state_list
def __init__(self, idim, odim, args, device, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param device: kept on ``self.device``; not used elsewhere in this
        constructor
    :param int ignore_id: forwarded to the label-smoothing loss and kept on
        ``self.ignore_id``
    """
    torch.nn.Module.__init__(self)

    # NOTE: ``args`` is used as given (missing options are not filled in) and
    # is mutated in place: an unset attention dropout falls back to the
    # general dropout rate.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # options that the encoder and the decoder receive identically
    common = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_usebias=args.ldconv_usebias,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
    )

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=attn_dropout,
        **common
    )

    # the attention decoder and its loss only exist when mtlalpha < 1
    if args.mtlalpha < 1:
        self.decoder = Decoder(
            odim=odim,
            selfattention_layer_type=(
                args.transformer_decoder_selfattn_layer_type
            ),
            conv_kernel_length=args.ldconv_decoder_kernel_length,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            self_attention_dropout_rate=attn_dropout,
            src_attention_dropout_rate=attn_dropout,
            **common
        )
        self.criterion = LabelSmoothingLoss(
            odim,
            ignore_id,
            args.lsm_weight,
            args.transformer_length_normalized_loss,
        )
    else:
        self.decoder = None
        self.criterion = None

    self.blank = 0
    # <sos> and <eos> share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    # runs before the CTC module below is attached
    self.reset_parameters(args)

    self.adim = args.adim  # used for CTC (equal to d_model)
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    self.error_calculator = (
        ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        if args.report_cer or args.report_wer
        else None
    )
    self.rnnlm = None
    self.device = device
def __init__(self,
             num_time_mask=2,
             num_freq_mask=2,
             freq_mask_length=15,
             time_mask_length=15,
             feature_dim=320,
             model_size=512,
             feed_forward_size=1024,
             hidden_size=64,
             dropout=0.1,
             num_head=8,
             num_encoder_layer=6,
             num_decoder_layer=6,
             vocab_path='testing_vocab.model',
             max_feature_length=1024,
             max_token_length=50,
             enable_spec_augment=True,
             share_weight=True,
             smoothing=0.1,
             restrict_left_length=20,
             restrict_right_length=20,
             mtlalpha=0.2,
             report_wer=True):
    """Build the encoder/decoder model, its losses and optional helpers.

    Args:
        num_time_mask, num_freq_mask, freq_mask_length, time_mask_length:
            Forwarded to ``SpecAugment`` when ``enable_spec_augment`` is set.
        feature_dim: Encoder input dimension (``idim``).
        model_size: Attention dimension of encoder and decoder; also the
            input size of both output projections and of CTC.
        feed_forward_size: ``linear_units`` of encoder and decoder.
        hidden_size: Accepted but not used in this constructor.
        dropout: Dropout rate for encoder, decoder and CTC. The decoder's
            source-attention dropout is fixed to 0.
        num_head: Number of attention heads.
        num_encoder_layer, num_decoder_layer: Number of blocks.
        vocab_path: Path handed to ``Vocab``. When ``report_wer`` is set the
            token list is read from the same path with ``'.model'`` replaced
            by ``'.vocab'``.
        max_feature_length: ``max_sequence_length`` of ``SpecAugment``.
        max_token_length: Stored on the instance.
        enable_spec_augment: Whether to create ``self.spec_augment``; the
            attribute is not created otherwise.
        share_weight: Accepted but not used in this constructor.
        smoothing: Label smoothing for the token criterion.
        restrict_left_length, restrict_right_length: Stored on the instance.
        mtlalpha: A CTC module is created only when this is > 0.
        report_wer: Whether to create an ``ErrorCalculator``.
    """
    super(Transformer, self).__init__()

    self.enable_spec_augment = enable_spec_augment
    self.max_token_length = max_token_length
    self.restrict_left_length = restrict_left_length
    self.restrict_right_length = restrict_right_length

    self.vocab = Vocab(vocab_path)
    self.sos = self.vocab.bos_id
    self.eos = self.vocab.eos_id
    self.adim = model_size
    self.odim = self.vocab.vocab_size
    self.ignore_id = self.vocab.pad_id

    if enable_spec_augment:
        self.spec_augment = SpecAugment(
            num_time_mask=num_time_mask,
            num_freq_mask=num_freq_mask,
            freq_mask_length=freq_mask_length,
            time_mask_length=time_mask_length,
            max_sequence_length=max_feature_length)

    self.encoder = Encoder(idim=feature_dim,
                           attention_dim=model_size,
                           attention_heads=num_head,
                           linear_units=feed_forward_size,
                           num_blocks=num_encoder_layer,
                           dropout_rate=dropout,
                           positional_dropout_rate=dropout,
                           attention_dropout_rate=dropout,
                           input_layer='linear',
                           padding_idx=self.vocab.pad_id)

    # the decoder's own output layer is disabled; the two linear heads
    # below project its hidden states instead
    self.decoder = Decoder(odim=self.vocab.vocab_size,
                           attention_dim=model_size,
                           attention_heads=num_head,
                           linear_units=feed_forward_size,
                           num_blocks=num_decoder_layer,
                           dropout_rate=dropout,
                           positional_dropout_rate=dropout,
                           self_attention_dropout_rate=dropout,
                           src_attention_dropout_rate=0,
                           input_layer='embed',
                           use_output_layer=False)
    self.decoder_linear = t.nn.Linear(model_size,
                                      self.vocab.vocab_size,
                                      bias=True)
    self.decoder_switch_linear = t.nn.Linear(model_size, 4, bias=True)

    self.criterion = LabelSmoothingLoss(size=self.odim,
                                        smoothing=smoothing,
                                        padding_idx=self.vocab.pad_id,
                                        normalize_length=True)
    self.switch_criterion = LabelSmoothingLoss(
        size=4,
        smoothing=0,
        padding_idx=self.vocab.pad_id,
        normalize_length=True)

    self.mtlalpha = mtlalpha
    if mtlalpha > 0.0:
        self.ctc = CTC(self.odim,
                       eprojs=self.adim,
                       dropout_rate=dropout,
                       ctc_type='builtin',
                       reduce=False)
    else:
        self.ctc = None

    if report_wer:
        from espnet.nets.e2e_asr_common import ErrorCalculator

        # Read the token list explicitly as UTF-8: relying on the locale's
        # default encoding can fail or mis-decode non-ASCII tokens.
        token_path = vocab_path.replace('.model', '.vocab')
        with open(token_path, encoding='utf-8') as reader:
            # each line is "<token>\t<rest>"; only the token is kept
            self.char_list = [line.split('\t')[0] for line in reader]
        self.error_calculator = ErrorCalculator(
            char_list=self.char_list,
            sym_space=' ',
            sym_blank=self.vocab.blank_token,
            report_wer=True)
    else:
        self.error_calculator = None

    self.rnnlm = None
    self.reporter = Reporter()
    self.switch_loss = LabelSmoothingLoss(size=4,
                                          smoothing=0,
                                          padding_idx=0)

    print('initing')
    initialize(self, init_type='xavier_normal')
    print('inited')
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = False,
    decoder_normalize_before: bool = False,
    is_spk_layer_norm: bool = False,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    reduction_factor: int = 1,
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    duration_predictor_dropout_rate: float = 0.1,
    postnet_dropout_rate: float = 0.5,
    hparams=None,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Initialize FastSpeech module.

    NOTE: ``hparams`` defaults to ``None`` but ``hparams.is_multi_speakers``
    is read unconditionally, so an object providing it has to be passed.
    Multi-speaker setups additionally read ``hparams.n_speakers`` and
    ``hparams.spk_embed_dim``.
    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.reduction_factor = reduction_factor
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst
    self.spk_embed_dim = spk_embed_dim
    self.hparams = hparams
    if self.hparams.is_multi_speakers:
        self.spk_embed_integration_type = spk_embed_integration_type

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # options handed identically to the encoder and the decoder stack
    shared_stack_options = dict(
        attention_dim=adim,
        attention_heads=aheads,
        pos_enc_class=pos_enc_class,
        is_spk_layer_norm=is_spk_layer_norm,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        hparams=hparams,
    )

    # define encoder
    encoder_input_layer = torch.nn.Embedding(
        num_embeddings=idim,
        embedding_dim=adim,
        padding_idx=self.padding_idx,
    )

    if self.hparams.is_multi_speakers:
        self.speaker_embedding = torch.nn.Embedding(
            hparams.n_speakers, self.spk_embed_dim)
        # Glorot-style uniform bound: sqrt(3) * sqrt(2 / (fan_in + fan_out))
        bound = sqrt(3.0) * sqrt(
            2.0 / (hparams.n_speakers + self.spk_embed_dim))
        self.speaker_embedding.weight.data.uniform_(-bound, bound)
        # NOTE(review): assumed to belong to the multi-speaker branch; the
        # original nesting is not recoverable from the source formatting.
        self.spkemb_projection = torch.nn.Linear(hparams.spk_embed_dim,
                                                 hparams.spk_embed_dim)

    self.encoder = Encoder(
        idim=idim,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        **shared_stack_options
    )

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
            hparams=hparams,
        )
        # NOTE(review): assumed to be nested under ``use_gst``; confirm
        # against the original indentation.
        if self.hparams.style_embed_integration_type == "concat":
            self.gst_projection = torch.nn.Linear(adim + adim, adim)

    # define additional projection for speaker embedding
    if self.hparams.is_multi_speakers:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
        hparams=hparams,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    self.decoder = Encoder(
        idim=0,
        linear_units=dunits,
        num_blocks=dlayers,
        input_layer=None,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        attention_dropout_rate=transformer_dec_attn_dropout_rate,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
        **shared_stack_options
    )

    # define final projection
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # define postnet
    if postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # define criterions
    self.criterion = FastSpeechLoss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
    )
def __init__(self, idim, odim, args=None):
    """Initialize the Transformer TTS module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (Namespace, optional): Model options; options that are not
            given are filled in through ``fill_missing_args`` using
            ``self.add_arguments``.
    """
    # initialize base classes
    TTSInterface.__init__(self)
    torch.nn.Module.__init__(self)

    # fill missing arguments
    args = fill_missing_args(args, self.add_arguments)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.spk_embed_dim = args.spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = args.spk_embed_integration_type
    self.use_scaled_pos_enc = args.use_scaled_pos_enc
    self.reduction_factor = args.reduction_factor
    self.loss_type = args.loss_type
    self.use_guided_attn_loss = args.use_guided_attn_loss
    if self.use_guided_attn_loss:
        # -1 selects every encoder layer / every attention head
        self.num_layers_applied_guided_attn = (
            args.elayers
            if args.num_layers_applied_guided_attn == -1
            else args.num_layers_applied_guided_attn
        )
        self.num_heads_applied_guided_attn = (
            args.aheads
            if args.num_heads_applied_guided_attn == -1
            else args.num_heads_applied_guided_attn
        )
        self.modules_applied_guided_attn = args.modules_applied_guided_attn

    # use idx 0 as padding idx
    padding_idx = 0

    # get positional encoding class
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # define transformer encoder
    if args.eprenet_conv_layers == 0:
        # plain embedding when no encoder prenet is requested
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim,
            embedding_dim=args.adim,
            padding_idx=padding_idx,
        )
    else:
        # encoder prenet followed by a projection to the attention dim
        encoder_prenet = EncoderPrenet(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=0,
            econv_layers=args.eprenet_conv_layers,
            econv_chans=args.eprenet_conv_chans,
            econv_filts=args.eprenet_conv_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.eprenet_dropout_rate,
            padding_idx=padding_idx,
        )
        encoder_input_layer = torch.nn.Sequential(
            encoder_prenet,
            torch.nn.Linear(args.eprenet_conv_chans, args.adim),
        )
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=encoder_input_layer,
        dropout_rate=args.transformer_enc_dropout_rate,
        positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
        attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.encoder_normalize_before,
        concat_after=args.encoder_concat_after,
    )

    # define projection layer
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = args.adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, args.adim)

    # define transformer decoder
    if args.dprenet_layers == 0:
        decoder_input_layer = "linear"
    else:
        # decoder prenet followed by a projection to the attention dim
        decoder_prenet = DecoderPrenet(
            idim=odim,
            n_layers=args.dprenet_layers,
            n_units=args.dprenet_units,
            dropout_rate=args.dprenet_dropout_rate,
        )
        decoder_input_layer = torch.nn.Sequential(
            decoder_prenet,
            torch.nn.Linear(args.dprenet_units, args.adim),
        )
    self.decoder = Decoder(
        odim=-1,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.transformer_dec_dropout_rate,
        positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
        self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
        input_layer=decoder_input_layer,
        use_output_layer=False,
        pos_enc_class=pos_enc_class,
        normalize_before=args.decoder_normalize_before,
        concat_after=args.decoder_concat_after,
    )

    # define final projection
    self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)
    self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

    # define postnet
    if args.postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate,
        )

    # define loss function
    self.criterion = TransformerLoss(
        use_masking=args.use_masking,
        bce_pos_weight=args.bce_pos_weight,
    )
    if self.use_guided_attn_loss:
        self.attn_criterion = GuidedMultiHeadAttentionLoss(
            sigma=args.guided_attn_loss_sigma,
            alpha=args.guided_attn_loss_lambda,
        )

    # initialize parameters
    self._reset_parameters(
        init_type=args.transformer_init,
        init_enc_alpha=args.initial_encoder_alpha,
        init_dec_alpha=args.initial_decoder_alpha,
    )
# NOTE(review): fragment; ``adim`` is expected to be defined before this point.
odim = 5
model = "decoder"

# build the selected module and put it in eval mode
if model != "decoder":
    encoder = Encoder(
        idim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        dropout_rate=0.0,
        input_layer="embed",
    )
    encoder.eval()
else:
    decoder = Decoder(
        odim=odim,
        attention_dim=adim,
        linear_units=3,
        num_blocks=2,
        dropout_rate=0.0,
    )
    decoder.eval()

# random token ids, random memory and a causal mask over ``xlen`` steps
xlen = 100
xs = torch.randint(0, odim, (1, xlen))
memory = torch.randn(2, 500, adim)
mask = subsequent_mask(xlen).unsqueeze(0)

n_avg = 10
result = {"cached": [], "baseline": []}
for key, value in result.items():
    cache = None
    print(key)
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: forwarded to the label-smoothing loss and kept on
        ``self.ignore_id``
    :raises ValueError: if ``args.tie_src_tgt_embedding`` is set while
        ``idim`` and ``odim`` differ
    """
    torch.nn.Module.__init__(self)

    # NOTE: ``args`` is mutated in place: an unset attention dropout falls
    # back to the general dropout rate.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # options that the encoder and the decoder receive identically
    common = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
    )
    self.encoder = Encoder(
        idim=idim,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer="embed",
        attention_dropout_rate=attn_dropout,
        **common
    )
    self.decoder = Decoder(
        odim=odim,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        **common
    )

    self.pad = 0
    # <sos> and <eos> share the last output index
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="mt", arch="transformer")
    self.reporter = Reporter()

    # tie source and target embeddings
    if args.tie_src_tgt_embedding and idim != odim:
        raise ValueError(
            "When using tie_src_tgt_embedding, idim and odim must be equal."
        )
    if args.tie_src_tgt_embedding:
        self.encoder.embed[0].weight = self.decoder.embed[0].weight

    # tie embeddings and the classifier
    if args.tie_classifier:
        self.decoder.output_layer.weight = self.decoder.embed[0].weight

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.normalize_length = args.transformer_length_normalized_loss  # for PPL

    # runs after the weights above have been tied
    self.reset_parameters(args)
    self.adim = args.adim

    if not args.report_bleu:
        self.error_calculator = None
    else:
        from espnet.nets.e2e_mt_common import ErrorCalculator

        self.error_calculator = ErrorCalculator(
            args.char_list, args.sym_space, args.report_bleu
        )
    self.rnnlm = None

    # multilingual NMT related
    self.multilingual = args.multilingual
class FeedForwardTransformer(TTSInterface, torch.nn.Module): """Feed Forward Transformer for TTS a.k.a. FastSpeech. This is a module of FastSpeech, feed-forward Transformer with duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive processing during inference, resulting in fast decoding compared with auto-regressive Transformer. .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: https://arxiv.org/pdf/1905.09263.pdf """ @staticmethod def add_arguments(parser): """Add model-specific arguments to the parser.""" group = parser.add_argument_group( "feed-forward transformer model setting") # network structure related group.add_argument( "--adim", default=384, type=int, help="Number of attention transformation dimensions") group.add_argument("--aheads", default=4, type=int, help="Number of heads for multi head attention") group.add_argument("--elayers", default=6, type=int, help="Number of encoder layers") group.add_argument("--eunits", default=1536, type=int, help="Number of encoder hidden units") group.add_argument("--dlayers", default=6, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=1536, type=int, help="Number of decoder hidden units") group.add_argument("--positionwise-layer-type", default="linear", type=str, choices=["linear", "conv1d", "conv1d-linear"], help="Positionwise layer type.") group.add_argument("--positionwise-conv-kernel-size", default=3, type=int, help="Kernel size of positionwise conv1d layer") group.add_argument("--postnet-layers", default=0, type=int, help="Number of postnet layers") group.add_argument("--postnet-chans", default=256, type=int, help="Number of postnet channels") group.add_argument("--postnet-filts", default=5, type=int, help="Filter size of postnet") group.add_argument("--use-batch-norm", default=True, type=strtobool, help="Whether to use batch normalization") group.add_argument( "--use-scaled-pos-enc", 
default=True, type=strtobool, help= "Use trainable scaled positional encoding instead of the fixed scale one" ) group.add_argument( "--encoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before encoder block") group.add_argument( "--decoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before decoder block") group.add_argument( "--encoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in encoder" ) group.add_argument( "--decoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in decoder" ) group.add_argument("--duration-predictor-layers", default=2, type=int, help="Number of layers in duration predictor") group.add_argument("--duration-predictor-chans", default=384, type=int, help="Number of channels in duration predictor") group.add_argument("--duration-predictor-kernel-size", default=3, type=int, help="Kernel size in duration predictor") group.add_argument("--teacher-model", default=None, type=str, nargs="?", help="Teacher model file path") group.add_argument("--reduction-factor", default=1, type=int, help="Reduction factor") group.add_argument("--spk-embed-dim", default=None, type=int, help="Number of speaker embedding dimensions") group.add_argument("--spk-embed-integration-type", type=str, default="add", choices=["add", "concat"], help="How to integrate speaker embedding") # training related group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help="How to initialize transformer parameters") group.add_argument( "--initial-encoder-alpha", type=float, default=1.0, help="Initial alpha value in encoder's ScaledPositionalEncoding") group.add_argument( "--initial-decoder-alpha", type=float, default=1.0, help="Initial alpha value in decoder's 
ScaledPositionalEncoding") group.add_argument("--transformer-lr", default=1.0, type=float, help="Initial value of learning rate") group.add_argument("--transformer-warmup-steps", default=4000, type=int, help="Optimizer warmup steps") group.add_argument( "--transformer-enc-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder except for attention") group.add_argument( "--transformer-enc-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder positional encoding") group.add_argument( "--transformer-enc-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder self-attention") group.add_argument( "--transformer-dec-dropout-rate", default=0.1, type=float, help= "Dropout rate for transformer decoder except for attention and pos encoding" ) group.add_argument( "--transformer-dec-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder positional encoding") group.add_argument( "--transformer-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder self-attention") group.add_argument( "--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder-decoder attention") group.add_argument("--duration-predictor-dropout-rate", default=0.1, type=float, help="Dropout rate for duration predictor") group.add_argument("--postnet-dropout-rate", default=0.5, type=float, help="Dropout rate in postnet") group.add_argument("--transfer-encoder-from-teacher", default=True, type=strtobool, help="Whether to transfer teacher's parameters") group.add_argument( "--transferred-encoder-module", default="all", type=str, choices=["all", "embed"], help="Encoder modeules to be trasferred from teacher") # loss related group.add_argument( "--use-masking", default=True, type=strtobool, help="Whether to use masking in calculation of loss") group.add_argument( "--use-weighted-masking", 
default=False, type=strtobool, help="Whether to use weighted masking in calculation of loss")
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
                - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int): Kernel size of duration predictor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker embedding.
                - teacher_model (str): Teacher auto-regressive transformer model path.
                - reduction_factor (int): Reduction factor.
                - transformer_init (float): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
                - use_masking (bool): Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
                - transfer_encoder_from_teacher: Whether to transfer encoder using teacher encoder parameters.
                - transferred_encoder_module: Encoder module to be initialized using teacher parameters.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = args.reduction_factor
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.spk_embed_dim = args.spk_embed_dim
        # the integration type is only stored when speaker embeddings are enabled
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=args.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                # "add": project the speaker embedding to adim so it can be summed
                self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
            else:
                # otherwise: project the concatenation back down to adim
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=args.adim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=None,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)

        # define postnet (disabled when postnet_layers == 0)
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # initialize parameters
        # NOTE: this runs before the teacher is attached below, so the teacher's
        # loaded weights are not re-initialized here.
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

        # define teacher model
        if args.teacher_model is not None:
            self.teacher = self._load_teacher_model(args.teacher_model)
        else:
            self.teacher = None

        # define duration calculator
        if self.teacher is not None:
            self.duration_calculator = DurationCalculator(self.teacher)
        else:
            self.duration_calculator = None

        # transfer teacher parameters
        if self.teacher is not None and args.transfer_encoder_from_teacher:
            self._transfer_from_teacher(args.transferred_encoder_module)

        # define criterions
        self.criterion = FeedForwardTransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking)

    def _forward(self, xs, ilens, ys=None, olens=None, spembs=None, ds=None, is_inference=False):
        """Run encoder, duration predictor, length regulator, decoder and postnet.

        Args:
            xs: Batch of inputs fed to the encoder.
            ilens: Batch of input lengths, used to build the encoder and duration masks.
            ys: Batch of targets; only forwarded to the duration calculator when ``ds`` is None.
            olens: Batch of target lengths; when None the decoder runs without a mask.
            spembs: Batch of speaker embeddings; used only when ``spk_embed_dim`` is set.
            ds: Precalculated durations; when None (training path) they are obtained
                from ``self.duration_calculator`` without tracking gradients.
            is_inference (bool): If True, the length regulator uses the durations
                predicted by the duration predictor instead of ``ds``.

        Returns:
            tuple: ``(before_outs, after_outs, d_outs)`` when ``is_inference`` is True,
            otherwise ``(before_outs, after_outs, ds, d_outs)``.

        """
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens)  # (B, Lmax, adim)
        else:
            if ds is None:
                # NOTE(review): self.duration_calculator is None when no teacher model
                # was configured, so this path presumably requires a teacher - confirm.
                with torch.no_grad():
                    ds = self.duration_calculator(xs, ilens, ys, olens, spembs)  # (B, Tmax)
            d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None:
            if self.reduction_factor > 1:
                # lengths are counted in decoder steps, i.e. frames // reduction_factor
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            # the postnet output is added as a residual on top of before_outs
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        if is_inference:
            return before_outs, after_outs, d_outs
        else:
            return before_outs, after_outs, ds, d_outs

    def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        xs = xs[:, :max(ilens)]
        ys = ys[:, :max(olens)]
        if extras is not None:
            extras = extras[:, :max(ilens)].squeeze(-1)

        # forward propagation
        before_outs, after_outs, ds, d_outs = self._forward(xs,
                                                            ilens,
                                                            ys,
                                                            olens,
                                                            spembs=spembs,
                                                            ds=extras,
                                                            is_inference=False)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            # truncate each target length to a multiple of the reduction factor
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss
        if self.postnet is None:
            l1_loss, duration_loss = self.criterion(None, before_outs, d_outs,
                                                    ys, ds, ilens, olens)
        else:
            l1_loss, duration_loss = self.criterion(after_outs, before_outs,
                                                    d_outs, ys, ds, ilens,
                                                    olens)
        loss = l1_loss + duration_loss
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "duration_loss": duration_loss.item()
            },
            {
                "loss": loss.item()
            },
        ]

        # report extra information
        if self.use_scaled_pos_enc:
            report_keys += [
                {
                    "encoder_alpha": self.encoder.embed[-1].alpha.data.item()
                },
                {
                    "decoder_alpha": self.decoder.embed[-1].alpha.data.item()
                },
            ]
        self.reporter.report(report_keys)

        return loss

    def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1).

        Returns:
            dict: Dict of attention weights and outputs.

        """
        with torch.no_grad():
            # remove unnecessary padded part (for multi-gpus)
            xs = xs[:, :max(ilens)]
            ys = ys[:, :max(olens)]
            if extras is not None:
                extras = extras[:, :max(ilens)].squeeze(-1)

            # forward propagation
            # index 1 selects after_outs from the training-path return tuple
            outs = self._forward(xs,
                                 ilens,
                                 ys,
                                 olens,
                                 spembs=spembs,
                                 ds=extras,
                                 is_inference=False)[1]

        att_ws_dict = dict()
        for name, m in self.named_modules():
            if isinstance(m, MultiHeadedAttention):
                attn = m.attn.cpu().numpy()
                # trim padding per utterance; the module name decides which lengths apply
                if "encoder" in name:
                    attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())]
                elif "decoder" in name:
                    if "src" in name:
                        attn = [
                            a[:, :ol, :il] for a, il, ol in zip(
                                attn, ilens.tolist(), olens.tolist())
                        ]
                    elif "self" in name:
                        attn = [
                            a[:, :l, :l] for a, l in zip(attn, olens.tolist())
                        ]
                    else:
                        logging.warning("unknown attention module: " + name)
                else:
                    logging.warning("unknown attention module: " + name)
                # unrecognised modules are still stored, untrimmed
                att_ws_dict[name] = attn
        att_ws_dict["predicted_fbank"] = [
            m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist())
        ]

        return att_ws_dict

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Dummy for compatibility.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.

        """
        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)
        if spemb is not None:
            spembs = spemb.unsqueeze(0)
        else:
            spembs = None

        # inference
        _, outs, _ = self._forward(xs, ilens, spembs=spembs, is_inference=True)  # (1, L, odim)

        return outs[0], None, None

    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)

        Raises:
            NotImplementedError: If the integration type is neither "add" nor "concat".

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                    [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        # the mask is placed on the same device as the model parameters
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _load_teacher_model(self, model_path):
        """Load a teacher model from ``model_path`` and freeze its parameters.

        Args:
            model_path (str): Path passed to ``get_model_conf`` and ``torch_load``.

        Returns:
            The instantiated teacher model, with ``requires_grad`` set to False
            on every parameter.

        """
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # assert dimension is the same between teacher and student
        assert idim == self.idim
        assert odim == self.odim
        assert args.reduction_factor == self.reduction_factor

        # load teacher model
        from espnet.utils.dynamic_import import dynamic_import
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters
        for p in model.parameters():
            p.requires_grad = False

        return model

    def _reset_parameters(self, init_type, init_enc_alpha=1.0, init_dec_alpha=1.0):
        """Initialize parameters and, if used, the positional-encoding alphas.

        Args:
            init_type: Initialization type passed to ``initialize``.
            init_enc_alpha (float): Initial alpha for the encoder positional encoding.
            init_dec_alpha (float): Initial alpha for the decoder positional encoding.

        """
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    def _transfer_from_teacher(self, transferred_encoder_module):
        """Copy encoder parameters from the teacher model into this model.

        Args:
            transferred_encoder_module (str): "all" copies every encoder parameter;
                "embed" copies only ``encoder.embed[0].weight``.

        Raises:
            NotImplementedError: For any other value.

        """
        if transferred_encoder_module == "all":
            # parameters are paired positionally; names and shapes must match
            for (n1, p1), (n2, p2) in zip(self.encoder.named_parameters(),
                                          self.teacher.encoder.named_parameters()):
                assert n1 == n2, "It seems that encoder structure is different."
                assert p1.shape == p2.shape, "It seems that encoder size is different."
                p1.data.copy_(p2.data)
        elif transferred_encoder_module == "embed":
            student_shape = self.encoder.embed[0].weight.data.shape
            teacher_shape = self.teacher.encoder.embed[0].weight.data.shape
            assert student_shape == teacher_shape, "It seems that embed dimension is different."
            self.encoder.embed[0].weight.data.copy_(
                self.teacher.encoder.embed[0].weight.data)
        else:
            raise NotImplementedError("Support only all or embed.")

    @property
    def attention_plot_class(self):
        """Return plot class for attention weight plot."""
        return TTSPlot

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        Keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
        Also `loss.png` will be created as a figure visualizing `main/loss` and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "duration_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]

        return plot_keys
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    embed_dim: int = 512,
    eprenet_conv_layers: int = 3,
    eprenet_conv_chans: int = 256,
    eprenet_conv_filts: int = 5,
    dprenet_layers: int = 2,
    dprenet_units: int = 256,
    elayers: int = 6,
    eunits: int = 1024,
    adim: int = 512,
    aheads: int = 4,
    dlayers: int = 6,
    dunits: int = 1024,
    postnet_layers: int = 5,
    postnet_chans: int = 256,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = False,
    decoder_normalize_before: bool = False,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    reduction_factor: int = 1,
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    transformer_enc_dec_attn_dropout_rate: float = 0.1,
    eprenet_dropout_rate: float = 0.5,
    dprenet_dropout_rate: float = 0.5,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
    bce_pos_weight: float = 5.0,
    loss_type: str = "L1",
    use_guided_attn_loss: bool = True,
    num_heads_applied_guided_attn: int = 2,
    num_layers_applied_guided_attn: int = 2,
    modules_applied_guided_attn: List[str] = ["encoder-decoder"],
    guided_attn_loss_sigma: float = 0.4,
    guided_attn_loss_lambda: float = 1.0,
):
    """Initialize Transformer module.

    Builds, in this order, the encoder (with an optional convolutional
    prenet), the optional speaker-embedding projection, the decoder (with
    an optional prenet), the output projections, the optional postnet and
    the loss criteria, and finally initializes the parameters.

    Args:
        idim: Dimension of the inputs; ``idim - 1`` is stored as ``eos``.
        odim: Dimension of the outputs.
        embed_dim: Embedding dimension passed to the encoder prenet.
        eprenet_conv_layers: Number of encoder prenet conv layers. With 0
            the prenet is replaced by a plain ``torch.nn.Embedding``.
        eprenet_conv_chans: Encoder prenet conv channels.
        eprenet_conv_filts: Encoder prenet conv filter size.
        dprenet_layers: Number of decoder prenet layers. With 0 the
            decoder uses the ``"linear"`` input layer instead.
        dprenet_units: Decoder prenet hidden units.
        elayers: Number of encoder blocks.
        eunits: Encoder position-wise hidden units.
        adim: Attention dimension shared by encoder and decoder.
        aheads: Number of attention heads.
        dlayers: Number of decoder blocks.
        dunits: Decoder position-wise hidden units.
        postnet_layers: Number of postnet layers; 0 disables the postnet.
        postnet_chans: Postnet channels.
        postnet_filts: Postnet filter size.
        positionwise_layer_type: Position-wise layer type for the encoder.
        positionwise_conv_kernel_size: Kernel size for that layer.
        use_scaled_pos_enc: Use ``ScaledPositionalEncoding`` instead of
            ``PositionalEncoding``.
        use_batch_norm: Passed to the encoder prenet and the postnet.
        encoder_normalize_before: Passed to the encoder.
        decoder_normalize_before: Passed to the decoder.
        encoder_concat_after: Passed to the encoder.
        decoder_concat_after: Passed to the decoder.
        reduction_factor: Output frames generated per decoder step.
        spk_embed_dim: Speaker embedding dimension; ``None`` disables the
            speaker-embedding projection.
        spk_embed_integration_type: ``"add"`` projects the embedding to
            ``adim``; any other value projects ``adim + spk_embed_dim``
            down to ``adim``.
        transformer_enc_dropout_rate: Encoder dropout rate.
        transformer_enc_positional_dropout_rate: Encoder positional
            dropout rate.
        transformer_enc_attn_dropout_rate: Encoder attention dropout rate.
        transformer_dec_dropout_rate: Decoder dropout rate.
        transformer_dec_positional_dropout_rate: Decoder positional
            dropout rate.
        transformer_dec_attn_dropout_rate: Decoder self-attention dropout.
        transformer_enc_dec_attn_dropout_rate: Decoder source-attention
            dropout.
        eprenet_dropout_rate: Encoder prenet dropout rate.
        dprenet_dropout_rate: Decoder prenet dropout rate.
        postnet_dropout_rate: Postnet dropout rate.
        init_type: Parameter initialization type.
        init_enc_alpha: Initial alpha for the encoder positional encoding.
        init_dec_alpha: Initial alpha for the decoder positional encoding.
        use_masking: Passed to ``TransformerLoss``.
        use_weighted_masking: Passed to ``TransformerLoss``.
        bce_pos_weight: Passed to ``TransformerLoss``.
        loss_type: Stored as ``self.loss_type``.
        use_guided_attn_loss: Whether to create the guided attention loss.
        num_heads_applied_guided_attn: Heads the guided attention loss is
            applied to; ``-1`` means ``aheads``.
        num_layers_applied_guided_attn: Layers the guided attention loss
            is applied to; ``-1`` means ``elayers``.
        modules_applied_guided_attn: Module names the guided attention
            loss is applied to.
        guided_attn_loss_sigma: Sigma of the guided attention loss.
        guided_attn_loss_lambda: Weight (alpha) of the guided attention loss.

    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.spk_embed_dim = spk_embed_dim
    self.reduction_factor = reduction_factor
    self.use_guided_attn_loss = use_guided_attn_loss
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.loss_type = loss_type
    if self.use_guided_attn_loss:
        # -1 is a sentinel meaning "apply to all layers / all heads"
        if num_layers_applied_guided_attn == -1:
            self.num_layers_applied_guided_attn = elayers
        else:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
        if num_heads_applied_guided_attn == -1:
            self.num_heads_applied_guided_attn = aheads
        else:
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
        # Copy so that instances never share (and cannot mutate) the
        # default list object declared in the signature.
        self.modules_applied_guided_attn = list(modules_applied_guided_attn)
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    # define transformer encoder
    if eprenet_conv_layers != 0:
        # encoder prenet, followed by a projection to the attention dimension
        encoder_input_layer = torch.nn.Sequential(
            EncoderPrenet(
                idim=idim,
                embed_dim=embed_dim,
                elayers=0,
                econv_layers=eprenet_conv_layers,
                econv_chans=eprenet_conv_chans,
                econv_filts=eprenet_conv_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=eprenet_dropout_rate,
                padding_idx=self.padding_idx,
            ),
            torch.nn.Linear(eprenet_conv_chans, adim),
        )
    else:
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim,
            embedding_dim=adim,
            padding_idx=self.padding_idx)
    self.encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )

    # define projection layer
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                              adim)

    # define transformer decoder
    if dprenet_layers != 0:
        # decoder prenet, followed by a projection to the attention dimension
        decoder_input_layer = torch.nn.Sequential(
            DecoderPrenet(
                idim=odim,
                n_layers=dprenet_layers,
                n_units=dprenet_units,
                dropout_rate=dprenet_dropout_rate,
            ),
            torch.nn.Linear(dprenet_units, adim),
        )
    else:
        decoder_input_layer = "linear"
    self.decoder = Decoder(
        odim=odim,  # odim is needed when no prenet is used
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=dunits,
        num_blocks=dlayers,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        self_attention_dropout_rate=transformer_dec_attn_dropout_rate,
        src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate,
        input_layer=decoder_input_layer,
        use_output_layer=False,
        pos_enc_class=pos_enc_class,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
    )

    # define final projection
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
    self.prob_out = torch.nn.Linear(adim, reduction_factor)

    # define postnet
    self.postnet = (None if postnet_layers == 0 else Postnet(
        idim=idim,
        odim=odim,
        n_layers=postnet_layers,
        n_chans=postnet_chans,
        n_filts=postnet_filts,
        use_batch_norm=use_batch_norm,
        dropout_rate=postnet_dropout_rate,
    ))

    # define loss function
    self.criterion = TransformerLoss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
        bce_pos_weight=bce_pos_weight,
    )
    if self.use_guided_attn_loss:
        self.attn_criterion = GuidedMultiHeadAttentionLoss(
            sigma=guided_attn_loss_sigma,
            alpha=guided_attn_loss_lambda,
        )

    # initialize parameters
    # The decoder alpha must come from init_dec_alpha; previously
    # init_enc_alpha was passed for both, ignoring the caller's value.
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct the model from input/output dimensions and an option namespace.

    :param int idim: dimension of inputs (passed to the encoder)
    :param int odim: dimension of outputs; ``odim - 1`` is used as both sos and eos
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored as ``self.ignore_id`` and passed to the loss
    """
    torch.nn.Module.__init__(self)

    # The attention dropout falls back to the general dropout rate when it
    # is unset; the resolved value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # keyword arguments shared by the encoder and the decoder
    shared_kwargs = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
    )
    self.encoder = Encoder(
        idim=idim,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=attn_dropout,
        **shared_kwargs,
    )
    self.decoder = Decoder(
        odim=odim,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        **shared_kwargs,
    )

    # the last output index doubles as start- and end-of-sequence symbol
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.char_list = args.char_list
    self.sampling = 'multinomial'
    self.reporter = Reporter()
    self.ctxt = args.ctxt
    self.duration = int(args.duration)
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.verbose = args.verbose

    # parameters are (re)initialized before the CTC module is created below
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha

    # the CTC branch only exists when it has a non-zero multitask weight
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # the error calculator is imported lazily, only when CER/WER reporting is on
    if not (args.report_cer or args.report_wer):
        self.error_calculator = None
    else:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id stored as ``self.ignore_id`` and passed to the loss
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    # attention dropout defaults to the general dropout rate (written back to args)
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # keyword arguments that are identical for the encoder and the decoder
    shared_kwargs = dict(
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_usebias=args.ldconv_usebias,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
    )
    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer="embed",
        attention_dropout_rate=attn_dropout,
        **shared_kwargs,
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        **shared_kwargs,
    )

    self.pad = 0  # use <blank> for padding
    # the last output index doubles as start- and end-of-sequence symbol
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="mt", arch="transformer")
    self.reporter = Reporter()

    # tie source and target embeddings (requires equal vocabulary sizes)
    tie_embeddings = args.tie_src_tgt_embedding
    if tie_embeddings and idim != odim:
        raise ValueError(
            "When using tie_src_tgt_embedding, idim and odim must be equal."
        )
    if tie_embeddings:
        self.encoder.embed[0].weight = self.decoder.embed[0].weight

    # tie embeddings and the classifier
    if args.tie_classifier:
        self.decoder.output_layer.weight = self.decoder.embed[0].weight

    length_normalized = args.transformer_length_normalized_loss
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        length_normalized,
    )
    self.normalize_length = length_normalized  # for PPL

    # weight tying above happens before the parameters are (re)initialized
    self.reset_parameters(args)
    self.adim = args.adim
    self.error_calculator = ErrorCalculator(
        args.char_list, args.sym_space, args.sym_blank, args.report_bleu
    )
    self.rnnlm = None

    # multilingual MT related
    self.multilingual = args.multilingual
class E2E(torch.nn.Module): """E2E module. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ @staticmethod def add_arguments(parser): """Add arguments.""" group = parser.add_argument_group("transformer model setting") group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help='how to initialize transformer parameters') group.add_argument("--transformer-input-layer", type=str, default="conv2d", choices=["conv2d", "linear", "embed", "custom"], help='transformer input layer type') group.add_argument("--transformer-output-layer", type=str, default='embed', choices=['conv', 'embed', 'linear']) group.add_argument( '--transformer-attn-dropout-rate', default=None, type=float, help= 'dropout in transformer attention. use --dropout-rate if None is set' ) group.add_argument('--transformer-lr', default=10.0, type=float, help='Initial value of learning rate') group.add_argument('--transformer-warmup-steps', default=25000, type=int, help='optimizer warmup steps') group.add_argument('--transformer-length-normalized-loss', default=True, type=strtobool, help='normalize loss by length') group.add_argument('--dropout-rate', default=0.0, type=float, help='Dropout rate for the encoder') # Encoder group.add_argument( '--elayers', default=4, type=int, help= 'Number of encoder layers (for shared recognition part in multi-speaker asr mode)' ) group.add_argument('--eunits', '-u', default=300, type=int, help='Number of encoder hidden units') # Attention group.add_argument( '--adim', default=320, type=int, help='Number of attention transformation dimensions') group.add_argument('--aheads', default=4, type=int, help='Number of heads for multi head attention') # Decoder group.add_argument('--dlayers', default=1, type=int, help='Number of decoder layers') group.add_argument('--dunits', default=320, 
type=int, help='Number of decoder hidden units') # Streaming params group.add_argument( '--chunk', default=True, type=strtobool, help= 'streaming mode, set True for chunk-encoder, False for look-ahead encoder' ) group.add_argument('--chunk-size', default=16, type=int, help='chunk size for chunk-based encoder') group.add_argument( '--left-window', default=1000, type=int, help='left window size for look-ahead based encoder') group.add_argument( '--right-window', default=1000, type=int, help='right window size for look-ahead based encoder') group.add_argument( '--dec-left-window', default=0, type=int, help='left window size for decoder (look-ahead based method)') group.add_argument( '--dec-right-window', default=6, type=int, help='right window size for decoder (look-ahead based method)') return parser def __init__(self, idim, odim, args, ignore_id=-1): """Construct an E2E object. :param int idim: dimension of inputs :param int odim: dimension of outputs :param Namespace args: argument Namespace containing options """ torch.nn.Module.__init__(self) if args.transformer_attn_dropout_rate is None: args.transformer_attn_dropout_rate = args.dropout_rate self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=args.transformer_input_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, attention_dropout_rate=args.transformer_attn_dropout_rate, ) self.decoder = Decoder( odim=odim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=args.transformer_output_layer, dropout_rate=args.dropout_rate, positional_dropout_rate=args.dropout_rate, self_attention_dropout_rate=args.transformer_attn_dropout_rate, src_attention_dropout_rate=args.transformer_attn_dropout_rate) self.sos = odim - 1 self.eos = odim - 1 self.odim = odim self.ignore_id = ignore_id self.subsample = [1] # self.lsm_weight = a 
self.criterion = LabelSmoothingLoss( self.odim, self.ignore_id, args.lsm_weight, args.transformer_length_normalized_loss) # self.verbose = args.verbose self.reset_parameters(args) self.adim = args.adim self.mtlalpha = args.mtlalpha if args.mtlalpha > 0.0: self.ctc = CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True) else: self.ctc = None self.rnnlm = None self.left_window = args.dec_left_window self.right_window = args.dec_right_window def reset_parameters(self, args): """Initialize parameters.""" # initialize parameters initialize(self, args.transformer_init) def forward(self, xs_pad, ilens, ys_pad, enc_mask=None, dec_mask=None): """E2E forward. :param torch.Tensor xs_pad: batch of padded source sequences (B, Tmax, idim) :param torch.Tensor ilens: batch of lengths of source sequences (B) :param torch.Tensor ys_pad: batch of padded target sequences (B, Lmax) :return: ctc loass value :rtype: torch.Tensor :return: attention loss value :rtype: torch.Tensor :return: accuracy in attention decoder :rtype: float """ # 1. 
forward encoder xs_pad = xs_pad[:, :max(ilens)] # for data parallel batch_size = xs_pad.shape[0] src_mask = make_non_pad_mask(ilens.tolist()).to( xs_pad.device).unsqueeze(-2) if isinstance(self.encoder.embed, EncoderConv2d): xs, hs_mask = self.encoder.embed(xs_pad, torch.sum(src_mask, 2).squeeze()) hs_mask = hs_mask.unsqueeze(1) else: xs, hs_mask = self.encoder.embed(xs_pad, src_mask) if enc_mask is not None: enc_mask = enc_mask[:, :hs_mask.shape[2], :hs_mask.shape[2]] enc_mask = enc_mask & hs_mask if enc_mask is not None else hs_mask hs_pad, _ = self.encoder.encoders(xs, enc_mask) if self.encoder.normalize_before: hs_pad = self.encoder.after_norm(hs_pad) # CTC forward ys = [y[y != self.ignore_id] for y in ys_pad] y_len = max([len(y) for y in ys]) ys_pad = ys_pad[:, :y_len] if dec_mask is not None: dec_mask = dec_mask[:, :y_len + 1, :hs_pad.shape[1]] self.hs_pad = hs_pad batch_size = xs_pad.size(0) if self.mtlalpha == 0.0: loss_ctc = None else: batch_size = xs_pad.size(0) hs_len = hs_mask.view(batch_size, -1).sum(1) loss_ctc = self.ctc(hs_pad.view(batch_size, -1, self.adim), hs_len, ys_pad) # trigger mask hs_mask = hs_mask & dec_mask if dec_mask is not None else hs_mask # 2. forward decoder ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) ys_mask = target_mask(ys_in_pad, self.ignore_id) pred_pad, pred_mask = self.decoder(ys_in_pad, ys_mask, hs_pad, hs_mask) self.pred_pad = pred_pad # 3. 
compute attention loss loss_att = self.criterion(pred_pad, ys_out_pad) self.acc = th_accuracy(pred_pad.view(-1, self.odim), ys_out_pad, ignore_label=self.ignore_id) # copyied from e2e_asr alpha = self.mtlalpha if alpha == 0: self.loss = loss_att loss_att_data = float(loss_att) loss_ctc_data = None elif alpha == 1: self.loss = loss_ctc loss_att_data = None loss_ctc_data = float(loss_ctc) else: self.loss = alpha * loss_ctc + (1 - alpha) * loss_att loss_att_data = float(loss_att) loss_ctc_data = float(loss_ctc) return self.loss, loss_ctc_data, loss_att_data, self.acc def scorers(self): """Scorers.""" return dict(decoder=self.decoder, ctc=CTCPrefixScorer(self.ctc, self.eos)) def encode(self, x, mask=None): """Encode acoustic features. :param ndarray x: source acoustic feature (T, D) :return: encoder outputs :rtype: torch.Tensor """ self.eval() x = torch.as_tensor(x).unsqueeze(0).cuda() if mask is not None: mask = mask.cuda() if isinstance(self.encoder.embed, EncoderConv2d): hs, _ = self.encoder.embed( x, torch.Tensor([float(x.shape[1])]).cuda()) else: hs, _ = self.encoder.embed(x, None) hs, _ = self.encoder.encoders(hs, mask) if self.encoder.normalize_before: hs = self.encoder.after_norm(hs) return hs.squeeze(0) def viterbi_decode(self, x, y, mask=None): enc_output = self.encode(x, mask) logits = self.ctc.ctc_lo(enc_output).detach().data logit = np.array(logits.cpu().data).T align = viterbi_align(logit, y)[0] return align def ctc_decode(self, x, mask=None): enc_output = self.encode(x, mask) logits = self.ctc.argmax(enc_output.view(1, -1, 512)).detach().data path = np.array(logits.cpu()[0]) return path def recognize(self, x, recog_args, char_list=None, rnnlm=None, use_jit=False): """Recognize input speech. 
:param ndnarray x: input acoustic feature (B, T, D) or (T, D) :param Namespace recog_args: argment Namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list """ enc_output = self.encode(x).unsqueeze(0) if recog_args.ctc_weight > 0.0: lpz = self.ctc.log_softmax(enc_output) lpz = lpz.squeeze(0) else: lpz = None h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) # initialize hypothesis if rnnlm: hyp = {'score': 0.0, 'yseq': [y], 'rnnlm_prev': None} else: hyp = {'score': 0.0, 'yseq': [y]} if lpz is not None: import numpy from espnet.nets.ctc_prefix_score import CTCPrefixScore ctc_prefix_score = CTCPrefixScore(lpz.cpu().detach().numpy(), 0, self.eos, numpy) hyp['ctc_state_prev'] = ctc_prefix_score.initial_state() hyp['ctc_score_prev'] = 0.0 if ctc_weight != 1.0: # pre-pruning based on attention scores ctc_beam = min(lpz.shape[-1], int(beam * CTC_SCORING_RATIO)) else: ctc_beam = lpz.shape[-1] hyps = [hyp] ended_hyps = [] import six traced_decoder = None for i in six.moves.range(maxlen): logging.debug('position ' + str(i)) hyps_best_kept = [] for hyp in hyps: vy.unsqueeze(1) vy[0] = hyp['yseq'][i] # get nbest local scores and their ids ys_mask = subsequent_mask(i + 1).unsqueeze(0).cuda() ys = torch.tensor(hyp['yseq']).unsqueeze(0).cuda() # FIXME: jit does not match non-jit result if use_jit: if traced_decoder is None: traced_decoder = torch.jit.trace( self.decoder.forward_one_step, (ys, ys_mask, 
enc_output)) local_att_scores = traced_decoder(ys, ys_mask, enc_output)[0] else: local_att_scores = self.decoder.forward_one_step( ys, ys_mask, enc_output)[0] if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) local_scores = local_att_scores + recog_args.lm_weight * local_lm_scores else: local_scores = local_att_scores if lpz is not None: local_best_scores, local_best_ids = torch.topk( local_att_scores, ctc_beam, dim=1) ctc_scores, ctc_states = ctc_prefix_score( hyp['yseq'], local_best_ids[0].cpu(), hyp['ctc_state_prev']) local_scores = \ (1.0 - ctc_weight) * local_att_scores[:, local_best_ids[0]].cpu() \ + ctc_weight * torch.from_numpy(ctc_scores - hyp['ctc_score_prev']) if rnnlm: local_scores += recog_args.lm_weight * local_lm_scores[:, local_best_ids[ 0]].cpu() local_best_scores, joint_best_ids = torch.topk( local_scores, beam, dim=1) local_best_ids = local_best_ids[:, joint_best_ids[0]] else: local_best_scores, local_best_ids = torch.topk( local_scores, beam, dim=1) for j in six.moves.range(beam): new_hyp = {} new_hyp['score'] = hyp['score'] + float( local_best_scores[0, j]) new_hyp['yseq'] = [0] * (1 + len(hyp['yseq'])) new_hyp['yseq'][:len(hyp['yseq'])] = hyp['yseq'] new_hyp['yseq'][len(hyp['yseq'])] = int(local_best_ids[0, j]) if rnnlm: new_hyp['rnnlm_prev'] = rnnlm_state if lpz is not None: new_hyp['ctc_state_prev'] = ctc_states[joint_best_ids[ 0, j]] new_hyp['ctc_score_prev'] = ctc_scores[joint_best_ids[ 0, j]] # will be (2 x beam) hyps at most hyps_best_kept.append(new_hyp) hyps_best_kept = sorted(hyps_best_kept, key=lambda x: x['score'], reverse=True)[:beam] # sort and get nbest hyps = hyps_best_kept logging.debug('number of pruned hypothes: ' + str(len(hyps))) if char_list is not None: logging.debug( 'best hypo: ' + ''.join([char_list[int(x)] for x in hyps[0]['yseq'][1:]])) # add eos in the final loop to avoid that there are no ended hyps if i == maxlen - 1: logging.info('adding <eos> in the last postion in the loop') for hyp 
in hyps: hyp['yseq'].append(self.eos) # add ended hypothes to a final list, and removed them from current hypothes # (this will be a probmlem, number of hyps < beam) remained_hyps = [] for hyp in hyps: if hyp['yseq'][-1] == self.eos: # only store the sequence that has more than minlen outputs # also add penalty if len(hyp['yseq']) > minlen: hyp['score'] += (i + 1) * penalty if rnnlm: # Word LM needs to add final <eos> score hyp['score'] += recog_args.lm_weight * rnnlm.final( hyp['rnnlm_prev']) ended_hyps.append(hyp) else: remained_hyps.append(hyp) # end detection from espnet.nets.e2e_asr_common import end_detect if end_detect(ended_hyps, i) and recog_args.maxlenratio == 0.0: logging.info('end detected at %d', i) break hyps = remained_hyps if len(hyps) > 0: logging.debug('remeined hypothes: ' + str(len(hyps))) else: logging.info('no hypothesis. Finish decoding.') break if char_list is not None: for hyp in hyps: logging.debug( 'hypo: ' + ''.join([char_list[int(x)] for x in hyp['yseq'][1:]])) logging.debug('number of ended hypothes: ' + str(len(ended_hyps))) nbest_hyps = sorted( ended_hyps, key=lambda x: x['score'], reverse=True)[:min(len(ended_hyps), recog_args.nbest)] # check number of hypotheis if len(nbest_hyps) == 0: logging.warning( 'there is no N-best results, perform recognition again with smaller minlenratio.' 
) # should copy becasuse Namespace will be overwritten globally recog_args = Namespace(**vars(recog_args)) recog_args.minlenratio = max(0.0, recog_args.minlenratio - 0.1) return self.recognize(x, recog_args, char_list, rnnlm) logging.info('total log probability: ' + str(nbest_hyps[0]['score'])) logging.info('normalized log probability: ' + str(nbest_hyps[0]['score'] / len(nbest_hyps[0]['yseq']))) return nbest_hyps def prefix_recognize(self, x, recog_args, train_args, char_list=None, rnnlm=None): '''recognize feat :param ndnarray x: input acouctic feature (B, T, D) or (T, D) :param namespace recog_args: argment namespace contraining options :param list char_list: list of characters :param torch.nn.Module rnnlm: language model module :return: N-best decoding results :rtype: list TODO(karita): do not recompute previous attention for faster decoding ''' pad_len = self.eos - len(char_list) + 1 for i in range(pad_len): char_list.append('<eos>') if isinstance(self.encoder.embed, EncoderConv2d): seq_len = ((x.shape[0] + 1) // 2 + 1) // 2 else: seq_len = ((x.shape[0] - 1) // 2 - 1) // 2 if train_args.chunk: s = np.arange(0, seq_len, train_args.chunk_size) mask = adaptive_enc_mask(seq_len, s).unsqueeze(0) else: mask = turncated_mask(1, seq_len, train_args.left_window, train_args.right_window) enc_output = self.encode(x, mask).unsqueeze(0) lpz = torch.nn.functional.softmax(self.ctc.ctc_lo(enc_output), dim=-1) lpz = lpz.squeeze(0) h = enc_output.squeeze(0) logging.info('input lengths: ' + str(h.size(0))) h_len = h.size(0) # search parms beam = recog_args.beam_size penalty = recog_args.penalty ctc_weight = recog_args.ctc_weight # preprare sos y = self.sos vy = h.new_zeros(1).long() if recog_args.maxlenratio == 0: maxlen = h.shape[0] else: # maxlen >= 1 maxlen = max(1, int(recog_args.maxlenratio * h.size(0))) minlen = int(recog_args.minlenratio * h.size(0)) hyp = { 'score': 0.0, 'yseq': [y], 'rnnlm_prev': None, 'seq': char_list[y], 'last_time': [], "ctc_score": 0.0, 
"rnnlm_score": 0.0, "att_score": 0.0, "cache": None, "precache": None, "preatt_score": 0.0, "prev_score": 0.0 } hyps = {char_list[y]: hyp} hyps_att = {char_list[y]: hyp} Pb_prev, Pnb_prev = Counter(), Counter() Pb, Pnb = Counter(), Counter() Pjoint = Counter() lpz = lpz.cpu().detach().numpy() vocab_size = lpz.shape[1] r = np.ndarray((vocab_size), dtype=np.float32) l = char_list[y] Pb_prev[l] = 1 Pnb_prev[l] = 0 A_prev = [l] A_prev_id = [[y]] vy.unsqueeze(1) total_copy = time.time() - time.time() samelen = 0 hat_att = {} if mask is not None: chunk_pos = set(np.array(mask.sum(dim=-1))[0]) for i in chunk_pos: hat_att[i] = {} else: hat_att[enc_output.shape[1]] = {} for i in range(h_len): hyps_ctc = {} threshold = recog_args.threshold #self.threshold #np.percentile(r, 98) pos_ctc = np.where(lpz[i] > threshold)[0] #self.removeIlegal(hyps) if mask is not None: chunk_index = mask[0][i].sum().item() else: chunk_index = h_len hyps_res = {} for l, hyp in hyps.items(): if l in hat_att[chunk_index]: hyp['tmp_cache'] = hat_att[chunk_index][l]['cache'] hyp['tmp_att'] = hat_att[chunk_index][l]['att_scores'] else: hyps_res[l] = hyp tmp = self.clusterbyLength( hyps_res ) # This step clusters hyps according to length dict:{length,hyps} start = time.time() # pre-compute beam self.compute_hyps(tmp, i, h_len, enc_output, hat_att[chunk_index], mask, train_args.chunk) total_copy += time.time() - start # Assign score and tokens to hyps #print(hyps.keys()) for l, hyp in hyps.items(): if 'tmp_att' not in hyp: continue #Todo check why local_att_scores = hyp['tmp_att'] local_best_scores, local_best_ids = torch.topk( local_att_scores, 5, dim=1) pos_att = np.array(local_best_ids[0].cpu()) pos = np.union1d(pos_ctc, pos_att) hyp['pos'] = pos # pre-compute ctc beam hyps_ctc_compute = self.get_ctchyps2compute(hyps, hyps_ctc, i) hyps_res2 = {} for l, hyp in hyps_ctc_compute.items(): l_minus = ' '.join(l.split()[:-1]) if l_minus in hat_att[chunk_index]: hyp['tmp_cur_new_cache'] = 
hat_att[chunk_index][l_minus][ 'cache'] hyp['tmp_cur_att_scores'] = hat_att[chunk_index][l_minus][ 'att_scores'] else: hyps_res2[l] = hyp tmp2_cluster = self.clusterbyLength(hyps_res2) self.compute_hyps_ctc(tmp2_cluster, h_len, enc_output, hat_att[chunk_index], mask, train_args.chunk) for l, hyp in hyps.items(): start = time.time() l_id = hyp['yseq'] l_end = l_id[-1] vy[0] = l_end prefix_len = len(l_id) if rnnlm: rnnlm_state, local_lm_scores = rnnlm.predict( hyp['rnnlm_prev'], vy) else: rnnlm_state = None local_lm_scores = torch.zeros([1, len(char_list)]) r = lpz[i] * (Pb_prev[l] + Pnb_prev[l]) start = time.time() if 'tmp_att' not in hyp: continue #Todo check why local_att_scores = hyp['tmp_att'] new_cache = hyp['tmp_cache'] align = [0] * prefix_len align[:prefix_len - 1] = hyp['last_time'][:] align[-1] = i pos = hyp['pos'] if 0 in pos or l_end in pos: if l not in hyps_ctc: hyps_ctc[l] = {'yseq': l_id} hyps_ctc[l]['rnnlm_prev'] = hyp['rnnlm_prev'] hyps_ctc[l]['rnnlm_score'] = hyp['rnnlm_score'] if l_end != self.eos: hyps_ctc[l]['last_time'] = [0] * prefix_len hyps_ctc[l]['last_time'][:] = hyp['last_time'][:] hyps_ctc[l]['last_time'][-1] = i cur_att_scores = hyps_ctc_compute[l][ "tmp_cur_att_scores"] cur_new_cache = hyps_ctc_compute[l][ "tmp_cur_new_cache"] hyps_ctc[l]['att_score'] = hyp['preatt_score'] + \ float(cur_att_scores[0, l_end].data) hyps_ctc[l]['cur_att'] = float( cur_att_scores[0, l_end].data) hyps_ctc[l]['cache'] = cur_new_cache else: if len(hyps_ctc[l]["yseq"]) > 1: hyps_ctc[l]["end"] = True hyps_ctc[l]['last_time'] = [] hyps_ctc[l]['att_score'] = hyp['att_score'] hyps_ctc[l]['cur_att'] = 0 hyps_ctc[l]['cache'] = hyp['cache'] hyps_ctc[l]['prev_score'] = hyp['prev_score'] hyps_ctc[l]['preatt_score'] = hyp['preatt_score'] hyps_ctc[l]['precache'] = hyp['precache'] hyps_ctc[l]['seq'] = hyp['seq'] for c in list(pos): if c == 0: Pb[l] += lpz[i][0] * (Pb_prev[l] + Pnb_prev[l]) else: l_plus = l + " " + char_list[c] if l_plus not in hyps_ctc: hyps_ctc[l_plus] = 
{} if "end" in hyp: hyps_ctc[l_plus]['yseq'] = True hyps_ctc[l_plus]['yseq'] = [0] * (prefix_len + 1) hyps_ctc[l_plus]['yseq'][:len(hyp['yseq'])] = l_id hyps_ctc[l_plus]['yseq'][-1] = int(c) hyps_ctc[l_plus]['rnnlm_prev'] = rnnlm_state hyps_ctc[l_plus][ 'rnnlm_score'] = hyp['rnnlm_score'] + float( local_lm_scores[0, c].data) hyps_ctc[l_plus]['att_score'] = hyp['att_score'] \ + float(local_att_scores[0, c].data) hyps_ctc[l_plus]['cur_att'] = float( local_att_scores[0, c].data) hyps_ctc[l_plus]['cache'] = new_cache hyps_ctc[l_plus]['precache'] = hyp['cache'] hyps_ctc[l_plus]['preatt_score'] = hyp['att_score'] hyps_ctc[l_plus]['prev_score'] = hyp['score'] hyps_ctc[l_plus]['last_time'] = align hyps_ctc[l_plus]['rule_penalty'] = 0 hyps_ctc[l_plus]['seq'] = l_plus if l_end != self.eos and c == l_end: Pnb[l_plus] += lpz[i][l_end] * Pb_prev[l] Pnb[l] += lpz[i][l_end] * Pnb_prev[l] else: Pnb[l_plus] += r[c] if l_plus not in hyps: Pb[l_plus] += lpz[i][0] * (Pb_prev[l_plus] + Pnb_prev[l_plus]) Pnb[l_plus] += lpz[i][c] * Pnb_prev[l_plus] #total_copy += time.time() - start for l in hyps_ctc.keys(): if Pb[l] != 0 or Pnb[l] != 0: hyps_ctc[l]['ctc_score'] = np.log(Pb[l] + Pnb[l]) else: hyps_ctc[l]['ctc_score'] = float('-inf') local_score = hyps_ctc[l]['ctc_score'] + recog_args.ctc_lm_weight * hyps_ctc[l]['rnnlm_score'] + \ recog_args.penalty * (len(hyps_ctc[l]['yseq'])) hyps_ctc[l]['local_score'] = local_score hyps_ctc[l]['score'] = (1-recog_args.ctc_weight) * hyps_ctc[l]['att_score'] \ + recog_args.ctc_weight * hyps_ctc[l]['ctc_score'] + \ recog_args.penalty * (len(hyps_ctc[l]['yseq'])) + \ recog_args.lm_weight * hyps_ctc[l]['rnnlm_score'] Pb_prev = Pb Pnb_prev = Pnb Pb = Counter() Pnb = Counter() hyps1 = sorted(hyps_ctc.items(), key=lambda x: x[1]['local_score'], reverse=True)[:beam] hyps1 = dict(hyps1) hyps2 = sorted(hyps_ctc.items(), key=lambda x: x[1]['att_score'], reverse=True)[:beam] hyps2 = dict(hyps2) hyps = sorted(hyps_ctc.items(), key=lambda x: x[1]['score'], 
reverse=True)[:beam] hyps = dict(hyps) for key in hyps1.keys(): if key not in hyps: hyps[key] = hyps1[key] for key in hyps2.keys(): if key not in hyps: hyps[key] = hyps2[key] hyps = sorted(hyps.items(), key=lambda x: x[1]['score'], reverse=True)[:beam] hyps = dict(hyps) logging.info('input lengths: ' + str(h.size(0))) logging.info('max output length: ' + str(maxlen)) logging.info('min output length: ' + str(minlen)) if "<eos>" in hyps.keys(): del hyps["<eos>"] #for key in hyps.keys(): # logging.info("{0}\tctc:{1}\tatt:{2}\trnnlm:{3}\tscore:{4}".format(key,hyps[key]["ctc_score"],hyps[key]['att_score'], # hyps[key]['rnnlm_score'], hyps[key]['score'])) # print("!!!","Decoding None") best = list(hyps.keys())[0] ids = hyps[best]['yseq'] score = hyps[best]['score'] logging.info('score: ' + str(score)) #if l in hyps.keys(): # logging.info(l) #print(samelen,h_len) return best, ids, score def removeIlegal(self, hyps): max_y = max([len(hyp['yseq']) for l, hyp in hyps.items()]) for_remove = [] for l, hyp in hyps.items(): if max_y - len(hyp['yseq']) > 4: for_remove.append(l) for cur_str in for_remove: del hyps[cur_str] def clusterbyLength(self, hyps): tmp = {} for l, hyp in hyps.items(): prefix_len = len(hyp['yseq']) if prefix_len > 1 and hyp['yseq'][-1] == self.eos: continue else: if prefix_len not in tmp: tmp[prefix_len] = [] tmp[prefix_len].append(hyp) return tmp def compute_hyps(self, current_hyps, curren_frame, total_frame, enc_output, hat_att, enc_mask, chunk=True): for length, hyps_t in current_hyps.items(): ys_mask = subsequent_mask(length).unsqueeze(0).cuda() ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1) # print(ys_mask4use.shape) l_id = [hyp_t['yseq'] for hyp_t in hyps_t] ys4use = torch.tensor(l_id).cuda() enc_output4use = enc_output.repeat(len(hyps_t), 1, 1) if hyps_t[0]["cache"] is None: cache4use = None else: cache4use = [] for decode_num in range(len(hyps_t[0]["cache"])): current_cache = [] for hyp_t in hyps_t: current_cache.append( 
hyp_t["cache"][decode_num].squeeze(0)) # print( torch.stack(current_cache).shape) current_cache = torch.stack(current_cache) cache4use.append(current_cache) partial_mask4use = [] for hyp_t in hyps_t: #partial_mask4use.append(torch.ones([1, len(hyp_t['last_time'])+1, enc_mask.shape[1]]).byte()) align = [0] * length align[:length - 1] = hyp_t['last_time'][:] align[-1] = curren_frame align_tensor = torch.tensor(align).unsqueeze(0) if chunk: partial_mask = enc_mask[0][align_tensor] else: right_window = self.right_window partial_mask = trigger_mask(1, total_frame, align_tensor, self.left_window, right_window) partial_mask4use.append(partial_mask) partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1) local_att_scores_b, new_cache_b = self.decoder.forward_one_step( ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use) for idx, hyp_t in enumerate(hyps_t): hyp_t['tmp_cache'] = [ new_cache_b[decode_num][idx].unsqueeze(0) for decode_num in range(len(new_cache_b)) ] hyp_t['tmp_att'] = local_att_scores_b[idx].unsqueeze(0) hat_att[hyp_t['seq']] = {} hat_att[hyp_t['seq']]['cache'] = hyp_t['tmp_cache'] hat_att[hyp_t['seq']]['att_scores'] = hyp_t['tmp_att'] def get_ctchyps2compute(self, hyps, hyps_ctc, current_frame): tmp2 = {} for l, hyp in hyps.items(): l_id = hyp['yseq'] l_end = l_id[-1] if "pos" not in hyp: continue if 0 in hyp['pos'] or l_end in hyp['pos']: #l_minus = ' '.join(l.split()[:-1]) #if l_minus in hat_att: # hyps[l]['tmp_cur_new_cache'] = hat_att[l_minus]['cache'] # hyps[l]['tmp_cur_att_scores'] = hat_att[l_minus]['att_scores'] # continue if l not in hyps_ctc and l_end != self.eos: tmp2[l] = {'yseq': l_id} tmp2[l]['seq'] = l tmp2[l]['rnnlm_prev'] = hyp['rnnlm_prev'] tmp2[l]['rnnlm_score'] = hyp['rnnlm_score'] if l_end != self.eos: tmp2[l]['last_time'] = [0] * len(l_id) tmp2[l]['last_time'][:] = hyp['last_time'][:] tmp2[l]['last_time'][-1] = current_frame return tmp2 def compute_hyps_ctc(self, hyps_ctc_cluster, total_frame, enc_output, 
hat_att, enc_mask, chunk=True): for length, hyps_t in hyps_ctc_cluster.items(): ys_mask = subsequent_mask(length - 1).unsqueeze(0).cuda() ys_mask4use = ys_mask.repeat(len(hyps_t), 1, 1) l_id = [hyp_t['yseq'][:-1] for hyp_t in hyps_t] ys4use = torch.tensor(l_id).cuda() enc_output4use = enc_output.repeat(len(hyps_t), 1, 1) if "precache" not in hyps_t[0] or hyps_t[0]["precache"] is None: cache4use = None else: cache4use = [] for decode_num in range(len(hyps_t[0]["precache"])): current_cache = [] for hyp_t in hyps_t: # print(length, hyp_t["yseq"], hyp_t["cache"][0].shape, # hyp_t["cache"][2].shape, hyp_t["cache"][4].shape) current_cache.append( hyp_t["precache"][decode_num].squeeze(0)) current_cache = torch.stack(current_cache) cache4use.append(current_cache) partial_mask4use = [] for hyp_t in hyps_t: #partial_mask4use.append(torch.ones([1, len(hyp_t['last_time']), enc_mask.shape[1]]).byte()) align = hyp_t['last_time'] align_tensor = torch.tensor(align).unsqueeze(0) if chunk: partial_mask = enc_mask[0][align_tensor] else: right_window = self.right_window partial_mask = trigger_mask(1, total_frame, align_tensor, self.left_window, right_window) partial_mask4use.append(partial_mask) partial_mask4use = torch.stack(partial_mask4use).cuda().squeeze(1) local_att_scores_b, new_cache_b = \ self.decoder.forward_one_step(ys4use, ys_mask4use, enc_output4use, partial_mask4use, cache4use) for idx, hyp_t in enumerate(hyps_t): hyp_t['tmp_cur_new_cache'] = [ new_cache_b[decode_num][idx].unsqueeze(0) for decode_num in range(len(new_cache_b)) ] hyp_t['tmp_cur_att_scores'] = local_att_scores_b[ idx].unsqueeze(0) l_minus = ' '.join(hyp_t['seq'].split()[:-1]) hat_att[l_minus] = {} hat_att[l_minus]['att_scores'] = hyp_t['tmp_cur_att_scores'] hat_att[l_minus]['cache'] = hyp_t['tmp_cur_new_cache']
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E speech-translation model with optional ASR and MT branches.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id forwarded to the label-smoothing criterion
    """
    torch.nn.Module.__init__(self)

    # Attention dropout falls back to the general dropout rate; the resolved
    # value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    def stack_options():
        """Return the keyword options shared by every encoder/decoder stack."""
        return dict(
            attention_dim=args.adim,
            attention_heads=args.aheads,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
        )

    def make_decoder():
        """Build a decoder; ``decoder`` and ``decoder_asr`` use the same options."""
        return Decoder(
            odim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            **stack_options()
        )

    self.encoder = Encoder(
        idim=idim,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        **stack_options()
    )
    self.decoder = make_decoder()

    # <sos> and <eos> share the last output index; the pad id is 0.
    self.pad = 0
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.adim = args.adim
    self.subsample = get_subsample(args, mode='st', arch='transformer')
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # submodule for ASR task: built only when asr_weight > 0 and mtlalpha < 1
    self.mtlalpha = args.mtlalpha
    self.asr_weight = getattr(args, "asr_weight", 0.0)
    needs_asr_decoder = self.asr_weight > 0 and args.mtlalpha < 1
    if needs_asr_decoder:
        self.decoder_asr = make_decoder()

    # submodule for MT task: an encoder fed through the 'embed' input layer
    self.mt_weight = getattr(args, "mt_weight", 0.0)
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer='embed',
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0,
            **stack_options()
        )

    # Parameter reset runs once the stacks above exist; the CTC module is
    # created afterwards.
    self.reset_parameters(args)
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    wants_error_report = self.asr_weight > 0 and (args.report_cer
                                                  or args.report_wer)
    if wants_error_report:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None

    # multilingual E2E-ST related: multilingual mode requires replace_sos
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
    assert not self.multilingual or self.replace_sos
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E ASR model: encoder plus optional decoder and CTC.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id forwarded to the label-smoothing criterion
    """
    torch.nn.Module.__init__(self)

    # Attention dropout falls back to the general dropout rate; the resolved
    # value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    def stack_options():
        """Return the keyword options shared by the encoder and the decoder."""
        return dict(
            attention_dim=args.adim,
            attention_heads=args.aheads,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
        )

    self.encoder = Encoder(
        idim=idim,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        **stack_options()
    )
    # No decoder is built when mtlalpha >= 1.
    self.decoder = (
        Decoder(
            odim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
            **stack_options()
        )
        if args.mtlalpha < 1
        else None
    )

    # <sos> and <eos> share the last output index; the blank id is 0.
    self.blank = 0
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameter reset happens here; the CTC module is created afterwards.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # An error calculator is only created when CER or WER reporting is on.
    self.error_calculator = (
        ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        if args.report_cer or args.report_wer
        else None
    )
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E model with two encoders (cn/en) and one decoder.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id forwarded to the label-smoothing criterion
    """
    torch.nn.Module.__init__(self)

    # Attention dropout falls back to the general dropout rate; the resolved
    # value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    # Both encoders are built from identical options, cn first, then en.
    for encoder_name in ("cn_encoder", "en_encoder"):
        setattr(
            self,
            encoder_name,
            Encoder(
                idim=idim,
                attention_dim=args.adim,
                attention_heads=args.aheads,
                linear_units=args.eunits,
                num_blocks=args.elayers,
                input_layer=args.transformer_input_layer,
                dropout_rate=args.dropout_rate,
                positional_dropout_rate=args.dropout_rate,
                attention_dropout_rate=args.transformer_attn_dropout_rate,
            ),
        )
    self.decoder = Decoder(
        odim=odim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=args.transformer_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_attn_dropout_rate,
    )

    # <sos> and <eos> share the last output index.
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = [1]
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # An error calculator is only created when CER or WER reporting is on.
    report_errors = args.report_cer or args.report_wer
    if report_errors:
        from espnet.nets.e2e_asr_common import ErrorCalculator
        self.error_calculator = ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
    else:
        self.error_calculator = None
    self.rnnlm = None

    # yzl23 config
    self.interp_factor = args.interpolation_coe
    logging.warning("Interpolated moe_coes with {}".format(self.interp_factor))
    self.remove_blank_in_ctc_mode = True

    # Parameters are reset last, after every submodule (CTC included) exists.
    self.reset_parameters(args)
    size_report = "Model total size: {}M, requires_grad size: {}M".format(
        self.count_parameters(),
        self.count_parameters(requires_grad=True),
    )
    logging.warning(size_report)
def __init__(self, idim, odim, args, ignore_id=-1, blank_id=0):
    """Construct an E2E object for transducer model.

    Args:
        idim (int): dimension of inputs
        odim (int): dimension of outputs
        args (Namespace): argument Namespace containing options
        ignore_id (int): stored on the instance as ``ignore_id``
        blank_id (int): stored as ``blank_id`` and passed to the loss

    """
    torch.nn.Module.__init__(self)

    # Encoder: Transformer stack, or whatever ``encoder_for`` builds otherwise.
    transformer_encoder = args.etype == 'transformer'
    if transformer_encoder:
        self.subsample = [1]
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=args.transformer_input_layer,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
            attention_dropout_rate=args.transformer_attn_dropout_rate_encoder,
        )
    else:
        self.subsample = get_subsample(args, mode='asr', arch='rnn-t')
        self.encoder = encoder_for(args, idim, self.subsample)

    # Decoder: Transformer stack, attention decoder, or plain decoder.
    transformer_decoder = args.dtype == 'transformer'
    if transformer_encoder and not transformer_decoder:
        # With a Transformer encoder and a non-Transformer decoder,
        # ``eprojs`` is overwritten with ``adim`` before the decoder is built.
        args.eprojs = args.adim

    if transformer_decoder:
        self.decoder = Decoder(
            odim=odim,
            jdim=args.joint_dim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=args.transformer_dec_input_layer,
            dropout_rate=args.dropout_rate_decoder,
            positional_dropout_rate=args.dropout_rate_decoder,
            attention_dropout_rate=args.transformer_attn_dropout_rate_decoder,
        )
    elif args.rnnt_mode == 'rnnt-att':
        self.att = att_for(args)
        self.decoder = decoder_for(args, odim, self.att)
    else:
        self.decoder = decoder_for(args, odim)

    # Bookkeeping copied from the options and the call arguments.
    self.etype = args.etype
    self.dtype = args.dtype
    self.rnnt_mode = args.rnnt_mode
    self.odim = odim
    self.adim = args.adim
    self.sos = self.eos = odim - 1  # <sos> and <eos> share the last index
    self.blank_id = blank_id
    self.ignore_id = ignore_id
    self.space = args.sym_space
    self.blank = args.sym_blank

    self.reporter = Reporter()
    self.criterion = TransLoss(args.trans_type, self.blank_id)
    self.default_parameters(args)

    # An error calculator is only created when CER or WER reporting is on.
    report_errors = args.report_cer or args.report_wer
    if report_errors:
        from espnet.nets.e2e_asr_common import ErrorCalculatorTrans
        self.error_calculator = ErrorCalculatorTrans(self.decoder, args)
    else:
        self.error_calculator = None

    self.logzero = -1e10
    self.loss = None
    self.rnnlm = None
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E speech-translation model with optional ASR and MT branches.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options
    :param int ignore_id: label id forwarded to the label-smoothing criterion
    """
    torch.nn.Module.__init__(self)

    # fill missing arguments for compatibility
    args = fill_missing_args(args, self.add_arguments)

    # Attention dropout falls back to the general dropout rate; the resolved
    # value is written back onto ``args``.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate

    def stack_options():
        """Return the keyword options shared by every encoder/decoder stack."""
        return dict(
            attention_dim=args.adim,
            attention_heads=args.aheads,
            dropout_rate=args.dropout_rate,
            positional_dropout_rate=args.dropout_rate,
        )

    def decoder_attention_dropout():
        """Return the self-/source-attention dropout options of a decoder."""
        return dict(
            self_attention_dropout_rate=args.transformer_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_attn_dropout_rate,
        )

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        attention_dropout_rate=args.transformer_attn_dropout_rate,
        **stack_options()
    )

    translation_decoder_options = dict(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
    )
    translation_decoder_options.update(stack_options())
    translation_decoder_options.update(decoder_attention_dropout())
    self.decoder = Decoder(**translation_decoder_options)

    # <sos> and <eos> share the last output index.
    self.pad = 0  # use <blank> for padding
    self.sos = self.eos = odim - 1
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="st", arch="transformer")
    self.reporter = Reporter()
    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # submodule for ASR task: built only when asr_weight > 0 and mtlalpha < 1;
    # unlike the translation decoder it gets no layer-type/conv options.
    self.mtlalpha = args.mtlalpha
    self.asr_weight = args.asr_weight
    needs_asr_decoder = self.asr_weight > 0 and args.mtlalpha < 1
    if needs_asr_decoder:
        asr_decoder_options = dict(
            odim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
        )
        asr_decoder_options.update(stack_options())
        asr_decoder_options.update(decoder_attention_dropout())
        self.decoder_asr = Decoder(**asr_decoder_options)

    # submodule for MT task: an encoder fed through the "embed" input layer
    self.mt_weight = args.mt_weight
    if self.mt_weight > 0:
        self.encoder_mt = Encoder(
            idim=odim,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer="embed",
            attention_dropout_rate=args.transformer_attn_dropout_rate,
            padding_idx=0,
            **stack_options()
        )

    # NOTE: parameter reset is placed after the submodule initialization;
    # the CTC module is created afterwards.
    self.reset_parameters(args)
    self.adim = args.adim  # used for CTC (equal to d_model)
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate,
            ctc_type=args.ctc_type, reduce=True)
        if self.asr_weight > 0 and args.mtlalpha > 0.0
        else None
    )

    # translation error calculator
    self.error_calculator = MTErrorCalculator(
        args.char_list,
        args.sym_space,
        args.sym_blank,
        args.report_bleu,
    )
    # recognition error calculator
    self.error_calculator_asr = ASRErrorCalculator(
        args.char_list,
        args.sym_space,
        args.sym_blank,
        args.report_cer,
        args.report_wer,
    )
    self.rnnlm = None

    # multilingual E2E-ST related
    self.multilingual = getattr(args, "multilingual", False)
    self.replace_sos = getattr(args, "replace_sos", False)
def __init__(self, idim, odim, args=None):
    """Build the feed-forward Transformer network.

    The constructor creates, in this order: the encoder, the optional
    speaker-embedding projection, the duration predictor, the length
    regulator, the decoder, the output projection and the optional postnet.
    Parameters are then initialized, the optional teacher model is loaded,
    and finally the loss is created.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (Namespace, optional): Hyperparameters; passed through
            ``fill_missing_args(args, self.add_arguments)`` before use.
            - elayers (int): Number of encoder layers.
            - eunits (int): Number of encoder hidden units.
            - adim (int): Number of attention transformation dimensions.
            - aheads (int): Number of heads for multi head attention.
            - dlayers (int): Number of decoder layers.
            - dunits (int): Number of decoder hidden units.
            - positionwise_layer_type: Forwarded to encoder and decoder.
            - positionwise_conv_kernel_size: Forwarded to encoder and
                decoder.
            - use_scaled_pos_enc (bool): Whether to use trainable scaled
                positional encoding.
            - encoder_normalize_before (bool): Whether to perform layer
                normalization before encoder block.
            - decoder_normalize_before (bool): Whether to perform layer
                normalization before decoder block.
            - encoder_concat_after (bool): Whether to concatenate attention
                layer's input and output in encoder.
            - decoder_concat_after (bool): Whether to concatenate attention
                layer's input and output in decoder.
            - duration_predictor_layers (int): Number of duration predictor
                layers.
            - duration_predictor_chans (int): Number of duration predictor
                channels.
            - duration_predictor_kernel_size (int): Kernel size of duration
                predictor.
            - duration_predictor_dropout_rate: Dropout rate handed to the
                duration predictor.
            - spk_embed_dim (int): Number of speaker embedding dimensions
                (``None`` disables the speaker projection).
            - spk_embed_integration_type: How to integrate speaker
                embedding (``"add"`` or anything else for concatenation).
            - teacher_model (str): Teacher auto-regressive transformer
                model path (``None`` for no teacher).
            - reduction_factor (int): Reduction factor.
            - postnet_layers (int): Number of postnet layers; ``0`` means
                no postnet is created.
            - postnet_chans, postnet_filts, use_batch_norm,
                postnet_dropout_rate: Forwarded to ``Postnet``.
            - transformer_init: How to initialize transformer parameters
                (passed to ``_reset_parameters`` as ``init_type``).
            - initial_encoder_alpha, initial_decoder_alpha: Passed to
                ``_reset_parameters`` as ``init_enc_alpha`` /
                ``init_dec_alpha``.
            - transformer_lr (float): Initial value of learning rate
                (not read by this constructor).
            - transformer_warmup_steps (int): Optimizer warmup steps
                (not read by this constructor).
            - transformer_enc_dropout_rate (float): Dropout rate in encoder
                except attention & positional encoding.
            - transformer_enc_positional_dropout_rate (float): Dropout rate
                after encoder positional encoding.
            - transformer_enc_attn_dropout_rate (float): Dropout rate in
                encoder self-attention module.
            - transformer_dec_dropout_rate (float): Dropout rate in decoder
                except attention & positional encoding.
            - transformer_dec_positional_dropout_rate (float): Dropout rate
                after decoder positional encoding.
            - transformer_dec_attn_dropout_rate (float): Dropout rate in
                decoder self-attention module.
            - transformer_enc_dec_attn_dropout_rate (float): Dropout rate
                in encoder-decoder attention module (not read by this
                constructor).
            - use_masking (bool): Whether to apply masking for padded part
                in loss calculation.
            - use_weighted_masking (bool): Whether to apply weighted
                masking in loss calculation.
            - transfer_encoder_from_teacher: Whether to transfer encoder
                using teacher encoder parameters.
            - transferred_encoder_module: Encoder module to be initialized
                using teacher parameters.

    """
    # Both base classes are initialized explicitly.
    TTSInterface.__init__(self)
    torch.nn.Module.__init__(self)

    # Fill missing arguments.
    args = fill_missing_args(args, self.add_arguments)
    adim = args.adim

    # Hyperparameters kept on the instance.
    self.idim = idim
    self.odim = odim
    self.reduction_factor = args.reduction_factor
    self.use_scaled_pos_enc = args.use_scaled_pos_enc
    self.spk_embed_dim = args.spk_embed_dim
    has_spk_embed = self.spk_embed_dim is not None
    if has_spk_embed:
        # Only defined when a speaker embedding is configured.
        self.spk_embed_integration_type = args.spk_embed_integration_type

    # Positional encoding class shared by encoder and decoder.
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # Position-wise layer options are identical for encoder and decoder.
    positionwise_opts = dict(
        positionwise_layer_type=args.positionwise_layer_type,
        positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
    )

    # Encoder: token embedding (index 0 is the padding index) + blocks.
    token_embedding = torch.nn.Embedding(
        num_embeddings=idim,
        embedding_dim=adim,
        padding_idx=0,
    )
    self.encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=token_embedding,
        dropout_rate=args.transformer_enc_dropout_rate,
        positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
        attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.encoder_normalize_before,
        concat_after=args.encoder_concat_after,
        **positionwise_opts,
    )

    # Additional projection for the speaker embedding: "add" maps the
    # embedding to adim, any other mode maps the concatenation back to adim.
    if has_spk_embed:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, adim)

    # Duration predictor.
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=args.duration_predictor_layers,
        n_chans=args.duration_predictor_chans,
        kernel_size=args.duration_predictor_kernel_size,
        dropout_rate=args.duration_predictor_dropout_rate,
    )

    # Length regulator.
    self.length_regulator = LengthRegulator()

    # Decoder.
    # NOTE: we use encoder as decoder because fastspeech's decoder is the
    # same as encoder; it is built without an input layer.
    self.decoder = Encoder(
        idim=0,
        attention_dim=adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        input_layer=None,
        dropout_rate=args.transformer_dec_dropout_rate,
        positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
        attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.decoder_normalize_before,
        concat_after=args.decoder_concat_after,
        **positionwise_opts,
    )

    # Final projection to odim * reduction_factor features.
    self.feat_out = torch.nn.Linear(adim, odim * args.reduction_factor)

    # Postnet (skipped entirely when zero layers are requested).
    if args.postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate,
        )

    # Initialize parameters; this runs before the teacher is attached.
    self._reset_parameters(
        init_type=args.transformer_init,
        init_enc_alpha=args.initial_encoder_alpha,
        init_dec_alpha=args.initial_decoder_alpha,
    )

    # Teacher model and the duration calculator that depends on it.
    teacher_path = args.teacher_model
    self.teacher = (
        None if teacher_path is None else self._load_teacher_model(teacher_path)
    )
    has_teacher = self.teacher is not None
    self.duration_calculator = (
        DurationCalculator(self.teacher) if has_teacher else None
    )

    # Transfer teacher parameters into this model when requested.
    if has_teacher and args.transfer_encoder_from_teacher:
        self._transfer_from_teacher(args.transferred_encoder_module)

    # Criterion.
    self.criterion = FeedForwardTransformerLoss(
        use_masking=args.use_masking,
        use_weighted_masking=args.use_weighted_masking,
    )
def __init__(self, idim, odim, args, ignore_id=-1):
    """Construct an E2E object.

    Builds the encoder and decoder, the label-smoothing criterion, the
    optional CTC module and the optional error calculator.

    :param int idim: dimension of inputs
    :param int odim: dimension of outputs
    :param Namespace args: argument Namespace containing options;
        ``args.transformer_attn_dropout_rate`` is overwritten in place with
        ``args.dropout_rate`` when it is ``None``
    :param int ignore_id: stored as ``self.ignore_id`` and handed to the
        label-smoothing loss (presumably the label id to ignore; confirm
        against ``LabelSmoothingLoss``)
    """
    torch.nn.Module.__init__(self)

    # Fall back to the generic dropout rate for attention dropout.
    if args.transformer_attn_dropout_rate is None:
        args.transformer_attn_dropout_rate = args.dropout_rate
    attn_dropout = args.transformer_attn_dropout_rate

    # Optional span options; they are looked up with defaults because the
    # Namespace may not define them, and both encoder and decoder get them.
    span_opts = dict(
        span_init=getattr(args, "span_init", None),
        span_ratio=getattr(args, "span_ratio", None),
        ratio_adaptive=getattr(args, "ratio_adaptive", None),
    )

    self.encoder = Encoder(
        idim=idim,
        selfattention_layer_type=args.transformer_encoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_encoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=args.transformer_input_layer,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        attention_dropout_rate=attn_dropout,
        attention_type=getattr(args, "transformer_enc_attn_type", "self_attn"),
        max_attn_span=getattr(args, "enc_max_attn_span", [None]),
        **span_opts,
    )
    self.decoder = Decoder(
        odim=odim,
        selfattention_layer_type=args.transformer_decoder_selfattn_layer_type,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        conv_wshare=args.wshare,
        conv_kernel_length=args.ldconv_decoder_kernel_length,
        conv_usebias=args.ldconv_usebias,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.dropout_rate,
        positional_dropout_rate=args.dropout_rate,
        self_attention_dropout_rate=attn_dropout,
        src_attention_dropout_rate=attn_dropout,
        attention_type=getattr(args, "transformer_dec_attn_type", "self_attn"),
        max_attn_span=getattr(args, "dec_max_attn_span", [None]),
        **span_opts,
    )

    # <sos> and <eos> share the last output index.
    last_index = odim - 1
    self.sos = last_index
    self.eos = last_index
    self.odim = odim
    self.ignore_id = ignore_id
    self.subsample = get_subsample(args, mode="asr", arch="transformer")
    self.reporter = Reporter()

    self.criterion = LabelSmoothingLoss(
        self.odim,
        self.ignore_id,
        args.lsm_weight,
        args.transformer_length_normalized_loss,
    )

    # Parameter reset runs once encoder, decoder and criterion exist and
    # before the CTC module is created.
    self.reset_parameters(args)
    self.adim = args.adim
    self.mtlalpha = args.mtlalpha

    # CTC is only built when it has a non-zero weight.
    self.ctc = (
        CTC(odim, args.adim, args.dropout_rate, ctc_type=args.ctc_type, reduce=True)
        if args.mtlalpha > 0.0
        else None
    )

    # The error calculator is only needed when CER or WER is reported.
    self.error_calculator = (
        ErrorCalculator(
            args.char_list,
            args.sym_space,
            args.sym_blank,
            args.report_cer,
            args.report_wer,
        )
        if args.report_cer or args.report_wer
        else None
    )
    self.rnnlm = None

    # Optional settings mirrored onto the instance: (attribute, option, default).
    for attr_name, option_name, fallback in (
        ("attention_enc_type", "transformer_enc_attn_type", "self_attn"),
        ("attention_dec_type", "transformer_dec_attn_type", "self_attn"),
        ("span_loss_coef", "span_loss_coef", None),
        ("ratio_adaptive", "ratio_adaptive", None),
    ):
        setattr(self, attr_name, getattr(args, option_name, fallback))
    self.sym_blank = args.sym_blank
def __init__(self, idim, odim, args=None):
    """Build the TTS-Transformer network.

    The constructor creates the encoder (with prenet or plain embedding),
    the optional speaker-embedding projection, the decoder (with prenet or
    a ``"linear"`` input layer), the two output projections, the optional
    postnet and the losses. Parameters are then initialized and, if
    requested, a pretrained model is loaded.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (Namespace, optional): Hyperparameters; passed through
            ``fill_missing_args(args, self.add_arguments)`` before use.
            - embed_dim (int): Dimension of character embedding.
            - eprenet_conv_layers (int): Number of encoder prenet
                convolution layers (``0`` uses a plain embedding instead).
            - eprenet_conv_chans (int): Number of encoder prenet
                convolution channels.
            - eprenet_conv_filts (int): Filter size of encoder prenet
                convolution.
            - dprenet_layers (int): Number of decoder prenet layers
                (``0`` uses the ``"linear"`` input layer instead).
            - dprenet_units (int): Number of decoder prenet hidden units.
            - elayers (int): Number of encoder layers.
            - eunits (int): Number of encoder hidden units.
            - adim (int): Number of attention transformation dimensions.
            - aheads (int): Number of heads for multi head attention.
            - dlayers (int): Number of decoder layers.
            - dunits (int): Number of decoder hidden units.
            - positionwise_layer_type: Forwarded to the encoder.
            - positionwise_conv_kernel_size: Forwarded to the encoder.
            - postnet_layers (int): Number of postnet layers (``0`` means
                no postnet is created).
            - postnet_chans (int): Number of postnet channels.
            - postnet_filts (int): Filter size of postnet.
            - use_scaled_pos_enc (bool): Whether to use trainable scaled
                positional encoding.
            - use_batch_norm (bool): Whether to use batch normalization in
                encoder prenet (also forwarded to the postnet).
            - encoder_normalize_before (bool): Whether to perform layer
                normalization before encoder block.
            - decoder_normalize_before (bool): Whether to perform layer
                normalization before decoder block.
            - encoder_concat_after (bool): Whether to concatenate attention
                layer's input and output in encoder.
            - decoder_concat_after (bool): Whether to concatenate attention
                layer's input and output in decoder.
            - reduction_factor (int): Reduction factor.
            - spk_embed_dim (int): Number of speaker embedding dimensions
                (``None`` disables the speaker projection).
            - spk_embed_integration_type: How to integrate speaker
                embedding (``"add"`` or anything else for concatenation).
            - transformer_init: How to initialize transformer parameters
                (passed to ``_reset_parameters`` as ``init_type``).
            - initial_encoder_alpha, initial_decoder_alpha: Passed to
                ``_reset_parameters`` as ``init_enc_alpha`` /
                ``init_dec_alpha``.
            - transformer_lr (float): Initial value of learning rate
                (not read by this constructor).
            - transformer_warmup_steps (int): Optimizer warmup steps
                (not read by this constructor).
            - transformer_enc_dropout_rate (float): Dropout rate in encoder
                except attention & positional encoding.
            - transformer_enc_positional_dropout_rate (float): Dropout rate
                after encoder positional encoding.
            - transformer_enc_attn_dropout_rate (float): Dropout rate in
                encoder self-attention module.
            - transformer_dec_dropout_rate (float): Dropout rate in decoder
                except attention & positional encoding.
            - transformer_dec_positional_dropout_rate (float): Dropout rate
                after decoder positional encoding.
            - transformer_dec_attn_dropout_rate (float): Dropout rate in
                decoder self-attention module.
            - transformer_enc_dec_attn_dropout_rate (float): Dropout rate
                in encoder-decoder attention module.
            - eprenet_dropout_rate (float): Dropout rate in encoder prenet.
            - dprenet_dropout_rate (float): Dropout rate in decoder prenet.
            - postnet_dropout_rate (float): Dropout rate in postnet.
            - use_masking (bool): Whether to apply masking for padded part
                in loss calculation.
            - use_weighted_masking (bool): Whether to apply weighted
                masking in loss calculation.
            - bce_pos_weight (float): Positive sample weight in bce
                calculation (only for use_masking=true).
            - loss_type (str): How to calculate loss.
            - use_guided_attn_loss (bool): Whether to use guided attention
                loss.
            - num_heads_applied_guided_attn (int): Number of heads in each
                layer to apply guided attention loss (``-1`` selects
                ``aheads``).
            - num_layers_applied_guided_attn (int): Number of layers to
                apply guided attention loss (``-1`` selects ``elayers``).
            - modules_applied_guided_attn (list): List of module names to
                apply guided attention loss.
            - guided_attn_loss_sigma (float): Sigma in guided attention
                loss.
            - guided_attn_loss_lambda (float): Lambda in guided attention
                loss.
            - pretrained_model: When not ``None``, handed to
                ``load_pretrained_model`` after parameter initialization.

    """
    # Both base classes are initialized explicitly.
    TTSInterface.__init__(self)
    torch.nn.Module.__init__(self)

    # Fill missing arguments.
    args = fill_missing_args(args, self.add_arguments)
    adim = args.adim

    # Hyperparameters kept on the instance.
    self.idim = idim
    self.odim = odim
    self.spk_embed_dim = args.spk_embed_dim
    has_spk_embed = self.spk_embed_dim is not None
    if has_spk_embed:
        # Only defined when a speaker embedding is configured.
        self.spk_embed_integration_type = args.spk_embed_integration_type
    self.use_scaled_pos_enc = args.use_scaled_pos_enc
    self.reduction_factor = args.reduction_factor
    self.loss_type = args.loss_type
    self.use_guided_attn_loss = args.use_guided_attn_loss
    if self.use_guided_attn_loss:
        # A value of -1 stands for "all layers" / "all heads".
        requested_layers = args.num_layers_applied_guided_attn
        requested_heads = args.num_heads_applied_guided_attn
        self.num_layers_applied_guided_attn = (
            args.elayers if requested_layers == -1 else requested_layers
        )
        self.num_heads_applied_guided_attn = (
            args.aheads if requested_heads == -1 else requested_heads
        )
        self.modules_applied_guided_attn = args.modules_applied_guided_attn

    # Index 0 is used as the padding index.
    padding_idx = 0

    # Positional encoding class shared by encoder and decoder.
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # Encoder input: plain embedding, or conv prenet followed by a linear
    # map from the prenet channels to adim.
    if args.eprenet_conv_layers == 0:
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim,
            embedding_dim=adim,
            padding_idx=padding_idx,
        )
    else:
        encoder_input_layer = torch.nn.Sequential(
            EncoderPrenet(
                idim=idim,
                embed_dim=args.embed_dim,
                elayers=0,
                econv_layers=args.eprenet_conv_layers,
                econv_chans=args.eprenet_conv_chans,
                econv_filts=args.eprenet_conv_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.eprenet_dropout_rate,
                padding_idx=padding_idx,
            ),
            torch.nn.Linear(args.eprenet_conv_chans, adim),
        )
    self.encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=encoder_input_layer,
        dropout_rate=args.transformer_enc_dropout_rate,
        positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
        attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.encoder_normalize_before,
        concat_after=args.encoder_concat_after,
        positionwise_layer_type=args.positionwise_layer_type,
        positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
    )

    # Projection for the speaker embedding: "add" maps the embedding to
    # adim, any other mode maps the concatenation back to adim.
    if has_spk_embed:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, adim)

    # Decoder input: the "linear" input layer, or a prenet followed by a
    # linear map from the prenet units to adim.
    if args.dprenet_layers == 0:
        decoder_input_layer = "linear"
    else:
        decoder_input_layer = torch.nn.Sequential(
            DecoderPrenet(
                idim=odim,
                n_layers=args.dprenet_layers,
                n_units=args.dprenet_units,
                dropout_rate=args.dprenet_dropout_rate,
            ),
            torch.nn.Linear(args.dprenet_units, adim),
        )
    # The decoder's own output layer is disabled; feat_out/prob_out below
    # produce the outputs instead.
    self.decoder = Decoder(
        odim=-1,
        attention_dim=adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        dropout_rate=args.transformer_dec_dropout_rate,
        positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
        self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
        src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
        input_layer=decoder_input_layer,
        use_output_layer=False,
        pos_enc_class=pos_enc_class,
        normalize_before=args.decoder_normalize_before,
        concat_after=args.decoder_concat_after,
    )

    # Final projections: odim * reduction_factor features and
    # reduction_factor values from prob_out.
    self.feat_out = torch.nn.Linear(adim, odim * args.reduction_factor)
    self.prob_out = torch.nn.Linear(adim, args.reduction_factor)

    # Postnet (skipped entirely when zero layers are requested).
    if args.postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate,
        )

    # Loss functions.
    self.criterion = TransformerLoss(
        use_masking=args.use_masking,
        use_weighted_masking=args.use_weighted_masking,
        bce_pos_weight=args.bce_pos_weight,
    )
    if self.use_guided_attn_loss:
        self.attn_criterion = GuidedMultiHeadAttentionLoss(
            sigma=args.guided_attn_loss_sigma,
            alpha=args.guided_attn_loss_lambda,
        )

    # Initialize parameters.
    self._reset_parameters(
        init_type=args.transformer_init,
        init_enc_alpha=args.initial_encoder_alpha,
        init_dec_alpha=args.initial_decoder_alpha,
    )

    # Load the pretrained model last, after parameter initialization.
    pretrained = args.pretrained_model
    if pretrained is not None:
        self.load_pretrained_model(pretrained)
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = False,
    decoder_normalize_before: bool = False,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    reduction_factor: int = 1,
    # duration predictor
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    # energy predictor
    energy_predictor_layers: int = 2,
    energy_predictor_chans: int = 384,
    energy_predictor_kernel_size: int = 3,
    energy_predictor_dropout: float = 0.5,
    energy_embed_kernel_size: int = 9,
    energy_embed_dropout: float = 0.5,
    stop_gradient_from_energy_predictor: bool = False,
    # pitch predictor
    pitch_predictor_layers: int = 2,
    pitch_predictor_chans: int = 384,
    pitch_predictor_kernel_size: int = 3,
    pitch_predictor_dropout: float = 0.5,
    pitch_embed_kernel_size: int = 9,
    pitch_embed_dropout: float = 0.5,
    stop_gradient_from_pitch_predictor: bool = False,
    # pretrained spk emb
    # NOTE(review): annotated ``int`` but defaults to None and is compared
    # against None below; ``Optional[int]`` would be the accurate annotation.
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    # GST
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    duration_predictor_dropout_rate: float = 0.1,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Initialize FastSpeech2 module.

    Submodules are created in this order: encoder, optional GST style
    encoder, optional speaker-embedding projection, duration predictor,
    pitch predictor and embedding, energy predictor and embedding, length
    regulator, decoder, output projection, optional postnet. Parameters are
    then initialized and the loss is created.

    Args:
        idim: Dimension of the inputs; also the size of the encoder
            embedding table. ``idim - 1`` is stored as ``self.eos``.
        odim: Dimension of the outputs; the output projection emits
            ``odim * reduction_factor`` features.
        adim: Attention dimension of encoder and decoder; also the
            embedding size, the input size of every predictor and the
            output channels of the pitch/energy embeddings.
        aheads: Number of attention heads (encoder and decoder).
        elayers: Number of encoder blocks.
        eunits: Number of encoder linear units.
        dlayers: Number of decoder blocks.
        dunits: Number of decoder linear units.
        postnet_layers: Number of postnet layers; ``0`` disables the
            postnet (``self.postnet`` is ``None``).
        postnet_chans: Forwarded to ``Postnet`` as ``n_chans``.
        postnet_filts: Forwarded to ``Postnet`` as ``n_filts``.
        positionwise_layer_type: Forwarded to encoder and decoder.
        positionwise_conv_kernel_size: Forwarded to encoder and decoder.
        use_scaled_pos_enc: Selects ``ScaledPositionalEncoding`` over
            ``PositionalEncoding``.
        use_batch_norm: Forwarded to ``Postnet``.
        encoder_normalize_before: Forwarded to the encoder.
        decoder_normalize_before: Forwarded to the decoder.
        encoder_concat_after: Forwarded to the encoder.
        decoder_concat_after: Forwarded to the decoder.
        reduction_factor: Reduction factor; stored on the instance.
        duration_predictor_layers: Duration predictor ``n_layers``.
        duration_predictor_chans: Duration predictor ``n_chans``.
        duration_predictor_kernel_size: Duration predictor kernel size.
        energy_predictor_layers: Energy predictor ``n_layers``.
        energy_predictor_chans: Energy predictor ``n_chans``.
        energy_predictor_kernel_size: Energy predictor kernel size.
        energy_predictor_dropout: Energy predictor dropout rate.
        energy_embed_kernel_size: Kernel size of the energy embedding
            convolution.
        energy_embed_dropout: Dropout applied after the energy embedding
            convolution.
        stop_gradient_from_energy_predictor: Stored on the instance.
        pitch_predictor_layers: Pitch predictor ``n_layers``.
        pitch_predictor_chans: Pitch predictor ``n_chans``.
        pitch_predictor_kernel_size: Pitch predictor kernel size.
        pitch_predictor_dropout: Pitch predictor dropout rate.
        pitch_embed_kernel_size: Kernel size of the pitch embedding
            convolution.
        pitch_embed_dropout: Dropout applied after the pitch embedding
            convolution.
        stop_gradient_from_pitch_predictor: Stored on the instance.
        spk_embed_dim: Speaker embedding dimension; ``None`` disables the
            speaker projection.
        spk_embed_integration_type: ``"add"`` projects the embedding to
            ``adim``; any other value projects the concatenation
            (``adim + spk_embed_dim``) to ``adim``.
        use_gst: Whether to create the ``StyleEncoder`` (``self.gst``).
        gst_tokens: Forwarded to ``StyleEncoder``.
        gst_heads: Forwarded to ``StyleEncoder``.
        gst_conv_layers: Forwarded to ``StyleEncoder`` as ``conv_layers``.
        gst_conv_chans_list: Forwarded as ``conv_chans_list``.
        gst_conv_kernel_size: Forwarded as ``conv_kernel_size``.
        gst_conv_stride: Forwarded as ``conv_stride``.
        gst_gru_layers: Forwarded as ``gru_layers``.
        gst_gru_units: Forwarded as ``gru_units``.
        transformer_enc_dropout_rate: Encoder dropout rate.
        transformer_enc_positional_dropout_rate: Encoder positional
            dropout rate.
        transformer_enc_attn_dropout_rate: Encoder attention dropout rate.
        transformer_dec_dropout_rate: Decoder dropout rate.
        transformer_dec_positional_dropout_rate: Decoder positional
            dropout rate.
        transformer_dec_attn_dropout_rate: Decoder attention dropout rate.
        duration_predictor_dropout_rate: Duration predictor dropout rate.
        postnet_dropout_rate: Postnet dropout rate.
        init_type: Passed to ``_reset_parameters``.
        init_enc_alpha: Passed to ``_reset_parameters``.
        init_dec_alpha: Passed to ``_reset_parameters``.
        use_masking: Forwarded to ``FastSpeech2Loss``.
        use_weighted_masking: Forwarded to ``FastSpeech2Loss``.

    """
    # Must stay the first statement: validates the call arguments against
    # the annotations above.
    assert check_argument_types()
    super().__init__()

    # Hyperparameters kept on the instance.
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
    self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst
    self.spk_embed_dim = spk_embed_dim
    has_spk_embed = self.spk_embed_dim is not None
    if has_spk_embed:
        # Only defined when a speaker embedding is configured.
        self.spk_embed_integration_type = spk_embed_integration_type

    # Index 0 is used as the padding index.
    self.padding_idx = 0

    # Positional encoding class shared by encoder and decoder.
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # Options that are identical for encoder and decoder.
    shared_opts = dict(
        attention_dim=adim,
        attention_heads=aheads,
        pos_enc_class=pos_enc_class,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )

    def _scalar_embedding(kernel_size, dropout):
        """Conv1d from one channel to adim channels, followed by dropout."""
        return torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=kernel_size,
                padding=(kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(dropout),
        )

    # Encoder: token embedding + blocks.
    token_embedding = torch.nn.Embedding(
        num_embeddings=idim,
        embedding_dim=adim,
        padding_idx=self.padding_idx,
    )
    self.encoder = Encoder(
        idim=idim,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=token_embedding,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        **shared_opts,
    )

    # GST style encoder.
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # Additional projection for the speaker embedding.
    if has_spk_embed:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, adim)

    # Duration predictor.
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # Pitch predictor and embedding.
    self.pitch_predictor = VariancePredictor(
        idim=adim,
        n_layers=pitch_predictor_layers,
        n_chans=pitch_predictor_chans,
        kernel_size=pitch_predictor_kernel_size,
        dropout_rate=pitch_predictor_dropout,
    )
    # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
    self.pitch_embed = _scalar_embedding(pitch_embed_kernel_size, pitch_embed_dropout)

    # Energy predictor and embedding.
    self.energy_predictor = VariancePredictor(
        idim=adim,
        n_layers=energy_predictor_layers,
        n_chans=energy_predictor_chans,
        kernel_size=energy_predictor_kernel_size,
        dropout_rate=energy_predictor_dropout,
    )
    # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
    self.energy_embed = _scalar_embedding(
        energy_embed_kernel_size, energy_embed_dropout
    )

    # Length regulator.
    self.length_regulator = LengthRegulator()

    # Decoder.
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    self.decoder = Encoder(
        idim=0,
        linear_units=dunits,
        num_blocks=dlayers,
        input_layer=None,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        attention_dropout_rate=transformer_dec_attn_dropout_rate,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
        **shared_opts,
    )

    # Final projection.
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # Postnet (skipped entirely when zero layers are requested).
    if postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )

    # Initialize parameters.
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # Criterion.
    self.criterion = FastSpeech2Loss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
    )