class Tacotron2(TTSInterface, torch.nn.Module):
    """VC Tacotron2 module for VC.

    This is a module of a Tacotron2-based voice conversion (VC) model,
    which converts a sequence of source acoustic features into a sequence
    of target acoustic features.

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--elayers", default=1, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # attention
        group.add_argument(
            "--atype",
            default="location",
            type=str,
            choices=["forward_ta", "forward", "location"],
            help="Type of attention mechanism",
        )
        group.add_argument(
            "--adim",
            default=512,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aconv-chans",
            default=32,
            type=int,
            help="Number of attention convolution channels",
        )
        group.add_argument(
            "--aconv-filts",
            default=15,
            type=int,
            help="Filter size of attention convolution",
        )
        group.add_argument(
            "--cumulate-att-w",
            default=True,
            type=strtobool,
            help="Whether or not to cumulate attention weights",
        )
        # decoder
        group.add_argument(
            "--dlayers", default=2, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1024, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--prenet-layers", default=2, type=int, help="Number of prenet layers"
        )
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument(
            "--postnet-layers", default=5, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=512, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # cbhg
        group.add_argument(
            "--use-cbhg",
            default=False,
            type=strtobool,
            help="Whether to use CBHG module",
        )
        group.add_argument(
            "--cbhg-conv-bank-layers",
            default=8,
            type=int,
            help="Number of convoluional bank layers in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-bank-chans",
            default=128,
            type=int,
            help="Number of convoluional bank channles in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-filts",
            default=3,
            type=int,
            help="Filter size of convoluional projection layer in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-chans",
            default=256,
            type=int,
            help="Number of convoluional projection channels in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-layers",
            default=4,
            type=int,
            help="Number of highway layers in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-units",
            default=128,
            type=int,
            help="Number of highway units in CBHG",
        )
        group.add_argument(
            "--cbhg-gru-units",
            default=256,
            type=int,
            help="Number of GRU units in CBHG",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help="Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument(
            "--dropout-rate", default=0.5, type=float, help="Dropout rate"
        )
        group.add_argument(
            "--zoneout-rate", default=0.1, type=float, help="Zoneout rate"
        )
        group.add_argument(
            "--reduction-factor",
            default=1,
            type=int,
            help="Reduction factor (for decoder)",
        )
        group.add_argument(
            "--encoder-reduction-factor",
            default=1,
            type=int,
            help="Reduction factor (for encoder)",
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions"
        )
        group.add_argument(
            "--pretrained-model", default=None, type=str, help="Pretrained model path"
        )
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--bce-pos-weight",
            default=20.0,
            type=float,
            help="Positive sample weight in BCE calculation "
            "(only for use-masking=True)",
        )
        group.add_argument(
            "--use-guided-attn-loss",
            default=False,
            type=strtobool,
            help="Whether to use guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-sigma",
            default=0.4,
            type=float,
            help="Sigma in guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in guided attention loss",
        )
        group.add_argument(
            "--src-reconstruction-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in source reconstruction loss",
        )
        group.add_argument(
            "--trg-reconstruction-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in target reconstruction loss",
        )
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs (source acoustic features).
            odim (int): Dimension of the outputs (target acoustic features).
            args (Namespace, optional): Missing fields are filled with the
                defaults of ``add_arguments``.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - atype (str): Attention type ("location", "forward" or
                    "forward_ta").
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The number of attention conv filter size.
                - cumulate_att_w (bool): Whether to cumulate previous attention
                    weight (forced to False for forward attention).
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of the activation function in
                    ``torch.nn.functional`` applied to the outputs (or None).
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - use_residual (bool): Whether to use residual connection in
                    encoder conv layers.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor of the decoder.
                - encoder_reduction_factor (int): Reduction factor of the encoder
                    (input frames are stacked by this factor).
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks
                    in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of
                    convolutional bank in CBHG.
                - cbhg_conv_proj_filts (int): The filter size of the projection
                    layer in CBHG.
                - cbhg_conv_proj_chans (int): The number of channels of the
                    projection layer in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway
                    network in CBHG.
                - cbhg_highway_units (int): The number of units of highway
                    network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss
                    calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token
                    (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention
                    loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.
                - src_reconstruction_loss_lambda (float): Weight of the source
                    reconstruction loss (disabled when <= 0).
                - trg_reconstruction_loss_lambda (float): Weight of the target
                    reconstruction loss (disabled when <= 0).
                - pretrained_model (str): Path of a pretrained model to load.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.adim = args.adim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.encoder_reduction_factor = args.encoder_reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss
        self.src_reconstruction_loss_lambda = args.src_reconstruction_loss_lambda
        self.trg_reconstruction_loss_lambda = args.trg_reconstruction_loss_lambda

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)" % args.output_activation
            )

        # define network modules
        self.enc = Encoder(
            idim=idim * args.encoder_reduction_factor,
            input_layer="linear",
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
        )
        # the encoder outputs eunits-dim features; speaker embeddings are
        # concatenated on top of them when given
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        if args.atype == "location":
            att = AttLoc(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward":
            att = AttForward(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
        if self.src_reconstruction_loss_lambda > 0:
            self.src_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            # the reconstructor is configured like self.enc, whose output width
            # is eunits (see dec_idim above), so project from eunits
            self.src_reconstructor_linear = torch.nn.Linear(
                args.eunits, idim * args.encoder_reduction_factor
            )
            self.src_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)
        if self.trg_reconstruction_loss_lambda > 0:
            self.trg_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            # same reasoning as for the source reconstructor projection
            self.trg_reconstructor_linear = torch.nn.Linear(
                args.eunits, odim * args.reduction_factor
            )
            self.trg_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

    def forward(
        self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded acoustic features (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Batch of stop-token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # thin out input frames for reduction factor
        # (B, Lmax, idim) -> (B, Lmax // r, idim * r)
        if self.encoder_reduction_factor > 1:
            B, Lmax, idim = xs.shape
            if Lmax % self.encoder_reduction_factor != 0:
                xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :]
            xs_ds = xs.contiguous().view(
                B,
                Lmax // self.encoder_reduction_factor,
                idim * self.encoder_reduction_factor,
            )
            ilens_ds = ilens.new(
                [ilen // self.encoder_reduction_factor for ilen in ilens]
            )
        else:
            xs_ds, ilens_ds = xs, ilens

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs_ds, ilens_ds)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)
        batch_size = hs.size(0)

        # calculate src reconstruction
        if self.src_reconstruction_loss_lambda > 0:
            xt, xtlens = self.src_reconstructor(hs, hlens)
            xt = self.src_reconstructor_linear(xt)
            if self.encoder_reduction_factor > 1:
                # undo the frame stacking: (B, T // r, idim * r) -> (B, T, idim)
                xt = xt.view(batch_size, -1, self.idim)

        # calculate trg reconstruction
        if self.trg_reconstruction_loss_lambda > 0:
            # attention-weighted context vector for every decoder step:
            # (B, Lmax // r, Tmax') x (B, Tmax', adim) -> (B, Lmax // r, adim)
            att_R = torch.bmm(att_ws, hs)
            # one context vector per decoder step, so the valid lengths are
            # olens // r (same as olens_in of the guided attention loss)
            olens_trg = olens.new([olen // self.reduction_factor for olen in olens])
            # the reconstructor packs its inputs, which needs lengths in
            # descending order; sort samples together with their own lengths
            # and restore the batch order afterwards
            att_R, olens_trg, sort_idx = self._sort_by_length(att_R, olens_trg)
            yt, ytlens = self.trg_reconstructor(att_R, olens_trg)
            yt, _ = self._revert_sort_by_length(yt, olens_trg, sort_idx)
            yt = self.trg_reconstructor_linear(yt)
            if self.reduction_factor > 1:
                # (B, Lmax // r, odim * r) -> (B, Lmax', odim)
                yt = yt.view(batch_size, -1, self.odim)

        # modifiy mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(
                self.reduction_factor
            ).all(), "Output length must be greater than or equal to reduction factor."
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(
                labels, 1, (olens - 1).unsqueeze(1), 1.0
            )  # see #3388
        if self.encoder_reduction_factor > 1:
            ilens = ilens.new(
                [ilen - ilen % self.encoder_reduction_factor for ilen in ilens]
            )
            max_in = max(ilens)
            xs = xs[:, :max_in]

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"bce_loss": bce_loss.item()},
        ]

        # calculate context_preservation loss (weighted by the configured lambdas)
        if self.src_reconstruction_loss_lambda > 0:
            src_recon_l1_loss, src_recon_mse_loss = self.src_reconstruction_loss(
                xt, xs, ilens
            )
            loss = loss + self.src_reconstruction_loss_lambda * src_recon_l1_loss
            report_keys += [
                {"src_recon_l1_loss": src_recon_l1_loss.item()},
                {"src_recon_mse_loss": src_recon_mse_loss.item()},
            ]
        if self.trg_reconstruction_loss_lambda > 0:
            trg_recon_l1_loss, trg_recon_mse_loss = self.trg_reconstruction_loss(
                yt, ys, olens
            )
            loss = loss + self.trg_reconstruction_loss_lambda * trg_recon_l1_loss
            report_keys += [
                {"trg_recon_l1_loss": trg_recon_l1_loss.item()},
                {"trg_recon_mse_loss": trg_recon_mse_loss.item()},
            ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive input
            # will be changed when r > 1
            if self.encoder_reduction_factor > 1:
                ilens_in = ilens.new(
                    [ilen // self.encoder_reduction_factor for ilen in ilens]
                )
            else:
                ilens_in = ilens
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens_in, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {"attn_loss": attn_loss.item()},
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {"cbhg_l1_loss": cbhg_l1_loss.item()},
                {"cbhg_mse_loss": cbhg_mse_loss.item()},
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of target features from source features.

        Args:
            x (Tensor): Input sequence of acoustic features (T, idim).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim), or CBHG outputs
                when use_cbhg is enabled.
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # thin out input frames for reduction factor
        # (Lmax, idim) -> (Lmax // r, idim * r)
        if self.encoder_reduction_factor > 1:
            Lmax, idim = x.shape
            if Lmax % self.encoder_reduction_factor != 0:
                x = x[: -(Lmax % self.encoder_reduction_factor), :]
            x_ds = x.contiguous().view(
                Lmax // self.encoder_reduction_factor,
                idim * self.encoder_reduction_factor,
            )
        else:
            x_ds = x

        # inference
        h = self.enc.inference(x_ds)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws

    def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded acoustic features (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, (torch.Tensor, np.ndarray)):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            # thin out input frames for reduction factor
            # (B, Lmax, idim) -> (B, Lmax // r, idim * r)
            if self.encoder_reduction_factor > 1:
                B, Lmax, idim = xs.shape
                if Lmax % self.encoder_reduction_factor != 0:
                    xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :]
                xs_ds = xs.contiguous().view(
                    B,
                    Lmax // self.encoder_reduction_factor,
                    idim * self.encoder_reduction_factor,
                )
                ilens_ds = [ilen // self.encoder_reduction_factor for ilen in ilens]
            else:
                xs_ds, ilens_ds = xs, ilens

            hs, hlens = self.enc(xs_ds, ilens_ds)
            if self.spk_embed_dim is not None:
                spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                hs = torch.cat([hs, spembs], dim=-1)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        also `loss.png` will be created as a figure visulizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"]
        if self.use_guided_attn_loss:
            plot_keys += ["attn_loss"]
        if self.use_cbhg:
            plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"]
        if self.src_reconstruction_loss_lambda > 0:
            plot_keys += ["src_recon_l1_loss", "src_recon_mse_loss"]
        if self.trg_reconstruction_loss_lambda > 0:
            plot_keys += ["trg_recon_l1_loss", "trg_recon_mse_loss"]
        return plot_keys

    def _sort_by_length(self, xs, ilens):
        """Sort a batch by descending length.

        Returns:
            Tensor: ``xs`` reordered along dim 0.
            Tensor: ``ilens`` reordered the same way.
            Tensor: Permutation indices, to pass to ``_revert_sort_by_length``.

        """
        sort_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], ilens[sort_idx], sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        """Restore the batch order changed by ``_sort_by_length``."""
        _, revert_idx = sort_idx.sort(0)
        return xs[revert_idx], ilens[revert_idx]
class Tacotron2(AbsTTS): """Tacotron2 module for end-to-end text-to-speech. This is a module of Spectrogram prediction network in Tacotron2 described in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_, which converts the sequence of characters into the sequence of Mel-filterbanks. .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`: https://arxiv.org/abs/1712.05884 Args: idim (int): Dimension of the inputs. odim: (int) Dimension of the outputs. spk_embed_dim (int, optional): Dimension of the speaker embedding. embed_dim (int, optional): Dimension of character embedding. elayers (int, optional): The number of encoder blstm layers. eunits (int, optional): The number of encoder blstm units. econv_layers (int, optional): The number of encoder conv layers. econv_filts (int, optional): The number of encoder conv filter size. econv_chans (int, optional): The number of encoder conv filter channels. dlayers (int, optional): The number of decoder lstm layers. dunits (int, optional): The number of decoder lstm units. prenet_layers (int, optional): The number of prenet layers. prenet_units (int, optional): The number of prenet units. postnet_layers (int, optional): The number of postnet layers. postnet_filts (int, optional): The number of postnet filter size. postnet_chans (int, optional): The number of postnet filter channels. output_activation (str, optional): The name of activation function for outputs. adim (int, optional): The number of dimension of mlp in attention. aconv_chans (int, optional): The number of attention conv filter channels. aconv_filts (int, optional): The number of attention conv filter size. cumulate_att_w (bool, optional): Whether to cumulate previous attention weight. use_batch_norm (bool, optional): Whether to use batch normalization. use_concate (bool, optional): Whether to concatenate encoder embedding with decoder lstm outputs. reduction_factor (int, optional): Reduction factor. 
spk_embed_dim (int, optional): Number of speaker embedding dimenstions. spk_embed_integration_type (str, optional): How to integrate speaker embedding. use_gst (str, optional): Whether to use global style token. gst_tokens (int, optional): The number of GST embeddings. gst_heads (int, optional): The number of heads in GST multihead attention. gst_conv_layers (int, optional): The number of conv layers in GST. gst_conv_chans_list: (Sequence[int], optional): List of the number of channels of conv layers in GST. gst_conv_kernel_size (int, optional): Kernal size of conv layers in GST. gst_conv_stride (int, optional): Stride size of conv layers in GST. gst_gru_layers (int, optional): The number of GRU layers in GST. gst_gru_units (int, optional): The number of GRU units in GST. dropout_rate (float, optional): Dropout rate. zoneout_rate (float, optional): Zoneout rate. use_masking (bool, optional): Whether to mask padded part in loss calculation. use_weighted_masking (bool, optional): Whether to apply weighted masking in loss calculation. bce_pos_weight (float, optional): Weight of positive sample of stop token (only for use_masking=True). loss_type (str, optional): How to calculate loss. use_guided_attn_loss (bool, optional): Whether to use guided attention loss. guided_attn_loss_sigma (float, optional): Sigma in guided attention loss. guided_attn_loss_lamdba (float, optional): Lambda in guided attention loss. 
""" def __init__( self, # network structure related idim: int, odim: int, embed_dim: int = 512, elayers: int = 1, eunits: int = 512, econv_layers: int = 3, econv_chans: int = 512, econv_filts: int = 5, atype: str = "location", adim: int = 512, aconv_chans: int = 32, aconv_filts: int = 15, cumulate_att_w: bool = True, dlayers: int = 2, dunits: int = 1024, prenet_layers: int = 2, prenet_units: int = 256, postnet_layers: int = 5, postnet_chans: int = 512, postnet_filts: int = 5, output_activation: str = None, use_batch_norm: bool = True, use_concate: bool = True, use_residual: bool = False, reduction_factor: int = 1, spk_embed_dim: int = None, spk_embed_integration_type: str = "concat", use_gst: bool = False, gst_tokens: int = 10, gst_heads: int = 4, gst_conv_layers: int = 6, gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), gst_conv_kernel_size: int = 3, gst_conv_stride: int = 2, gst_gru_layers: int = 1, gst_gru_units: int = 128, # training related dropout_rate: float = 0.5, zoneout_rate: float = 0.1, use_masking: bool = True, use_weighted_masking: bool = False, bce_pos_weight: float = 5.0, loss_type: str = "L1+L2", use_guided_attn_loss: bool = True, guided_attn_loss_sigma: float = 0.4, guided_attn_loss_lambda: float = 1.0, ): """Initialize Tacotron2 module.""" assert check_argument_types() super().__init__() # store hyperparameters self.idim = idim self.odim = odim self.eos = idim - 1 self.spk_embed_dim = spk_embed_dim self.cumulate_att_w = cumulate_att_w self.reduction_factor = reduction_factor self.use_gst = use_gst self.use_guided_attn_loss = use_guided_attn_loss self.loss_type = loss_type if self.spk_embed_dim is not None: self.spk_embed_integration_type = spk_embed_integration_type # define activation function for the final output if output_activation is None: self.output_activation_fn = None elif hasattr(F, output_activation): self.output_activation_fn = getattr(F, output_activation) else: raise ValueError(f"there is no such an activation 
function. " f"({output_activation})") # set padding idx padding_idx = 0 self.padding_idx = padding_idx # define network modules self.enc = Encoder( idim=idim, embed_dim=embed_dim, elayers=elayers, eunits=eunits, econv_layers=econv_layers, econv_chans=econv_chans, econv_filts=econv_filts, use_batch_norm=use_batch_norm, use_residual=use_residual, dropout_rate=dropout_rate, padding_idx=padding_idx, ) if self.use_gst: self.gst = StyleEncoder( idim=odim, # the input is mel-spectrogram gst_tokens=gst_tokens, gst_token_dim=eunits, gst_heads=gst_heads, conv_layers=gst_conv_layers, conv_chans_list=gst_conv_chans_list, conv_kernel_size=gst_conv_kernel_size, conv_stride=gst_conv_stride, gru_layers=gst_gru_layers, gru_units=gst_gru_units, ) if spk_embed_dim is None: dec_idim = eunits elif spk_embed_integration_type == "concat": dec_idim = eunits + spk_embed_dim elif spk_embed_integration_type == "add": dec_idim = eunits self.projection = torch.nn.Linear(self.spk_embed_dim, eunits) else: raise ValueError(f"{spk_embed_integration_type} is not supported.") if atype == "location": att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts) elif atype == "forward": att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts) if self.cumulate_att_w: logging.warning("cumulation of attention weights is disabled " "in forward attention.") self.cumulate_att_w = False elif atype == "forward_ta": att = AttForwardTA(dec_idim, dunits, adim, aconv_chans, aconv_filts, odim) if self.cumulate_att_w: logging.warning("cumulation of attention weights is disabled " "in forward attention.") self.cumulate_att_w = False else: raise NotImplementedError("Support only location or forward") self.dec = Decoder( idim=dec_idim, odim=odim, att=att, dlayers=dlayers, dunits=dunits, prenet_layers=prenet_layers, prenet_units=prenet_units, postnet_layers=postnet_layers, postnet_chans=postnet_chans, postnet_filts=postnet_filts, output_activation_fn=self.output_activation_fn, 
cumulate_att_w=self.cumulate_att_w, use_batch_norm=use_batch_norm, use_concate=use_concate, dropout_rate=dropout_rate, zoneout_rate=zoneout_rate, reduction_factor=reduction_factor, ) self.taco2_loss = Tacotron2Loss( use_masking=use_masking, use_weighted_masking=use_weighted_masking, bce_pos_weight=bce_pos_weight, ) if self.use_guided_attn_loss: self.attn_loss = GuidedAttentionLoss( sigma=guided_attn_loss_sigma, alpha=guided_attn_loss_lambda, ) def forward( self, text: torch.Tensor, text_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, spembs: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Calculate forward propagation. Args: text (LongTensor): Batch of padded character ids (B, Tmax). text_lengths (LongTensor): Batch of lengths of each input batch (B,). speech (Tensor): Batch of padded target features (B, Lmax, odim). speech_lengths (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Loss scalar value. Dict: Statistics to be monitored. Tensor: Weight value. 
""" text = text[:, :text_lengths.max()] # for data-parallel speech = speech[:, :speech_lengths.max()] # for data-parallel batch_size = text.size(0) # Add eos at the last of sequence xs = F.pad(text, [0, 1], "constant", self.padding_idx) for i, l in enumerate(text_lengths): xs[i, l] = self.eos ilens = text_lengths + 1 ys = speech olens = speech_lengths # make labels for stop prediction labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype) labels = F.pad(labels, [0, 1], "constant", 1.0) # calculate tacotron2 outputs after_outs, before_outs, logits, att_ws = self._forward( xs, ilens, ys, olens, spembs) # modify mod part of groundtruth if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_out = max(olens) ys = ys[:, :max_out] labels = labels[:, :max_out] labels[:, -1] = 1.0 # make sure at least one frame has 1 # calculate taco2 loss l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs, logits, ys, labels, olens) if self.loss_type == "L1+L2": loss = l1_loss + mse_loss + bce_loss elif self.loss_type == "L1": loss = l1_loss + bce_loss elif self.loss_type == "L2": loss = mse_loss + bce_loss else: raise ValueError(f"unknown --loss-type {self.loss_type}") stats = dict( l1_loss=l1_loss.item(), mse_loss=mse_loss.item(), bce_loss=bce_loss.item(), ) # calculate attention loss if self.use_guided_attn_loss: # NOTE(kan-bayashi): length of output for auto-regressive # input will be changed when r > 1 if self.reduction_factor > 1: olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: olens_in = olens attn_loss = self.attn_loss(att_ws, ilens, olens_in) loss = loss + attn_loss stats.update(attn_loss=attn_loss.item()) stats.update(loss=loss.item()) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight def _forward( self, xs: torch.Tensor, ilens: torch.Tensor, ys: torch.Tensor, olens: torch.Tensor, spembs: torch.Tensor, ) -> 
Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: hs, hlens = self.enc(xs, ilens) if self.use_gst: style_embs = self.gst(ys) hs = hs + style_embs.unsqueeze(1) if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spembs) return self.dec(hs, hlens, ys) def inference( self, text: torch.Tensor, speech: torch.Tensor = None, spembs: torch.Tensor = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 10.0, use_att_constraint: bool = False, backward_window: int = 1, forward_window: int = 3, use_teacher_forcing: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the sequence of features given the sequences of characters. Args: text (LongTensor): Input sequence of characters (T,). speech (Tensor, optional): Feature sequence to extract style (N, idim). spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). threshold (float, optional): Threshold in inference. minlenratio (float, optional): Minimum length ratio in inference. maxlenratio (float, optional): Maximum length ratio in inference. use_att_constraint (bool, optional): Whether to apply attention constraint. backward_window (int, optional): Backward window in attention constraint. forward_window (int, optional): Forward window in attention constraint. use_teacher_forcing (bool, optional): Whether to use teacher forcing. Returns: Tensor: Output sequence of features (L, odim). Tensor: Output sequence of stop probabilities (L,). Tensor: Attention weights (L, T). """ x = text y = speech spemb = spembs # add eos at the last of sequence x = F.pad(x, [0, 1], "constant", self.eos) # inference with teacher forcing if use_teacher_forcing: assert speech is not None, "speech must be provided with teacher forcing." 
xs, ys = x.unsqueeze(0), y.unsqueeze(0) spembs = None if spemb is None else spemb.unsqueeze(0) ilens = x.new_tensor([xs.size(1)]).long() olens = y.new_tensor([ys.size(1)]).long() outs, _, _, att_ws = self._forward(xs, ilens, ys, olens, spembs) return outs[0], None, att_ws[0] # inference h = self.enc.inference(x) if self.use_gst: style_emb = self.gst(y.unsqueeze(0)) h = h + style_emb if self.spk_embed_dim is not None: hs, spembs = h.unsqueeze(0), spemb.unsqueeze(0) h = self._integrate_with_spk_embed(hs, spembs)[0] outs, probs, att_ws = self.dec.inference( h, threshold=threshold, minlenratio=minlenratio, maxlenratio=maxlenratio, use_att_constraint=use_att_constraint, backward_window=backward_window, forward_window=forward_window, ) return outs, probs, att_ws def inference_( self, text: torch.Tensor, speech: torch.Tensor = None, spembs: torch.Tensor = None, threshold: float = 0.5, minlenratio: float = 0.0, maxlenratio: float = 10.0, use_att_constraint: bool = False, backward_window: int = 1, forward_window: int = 3, use_teacher_forcing: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the sequence of features given the sequences of characters. Args: text (LongTensor): Input sequence of characters (T,). speech (Tensor, optional): Feature sequence to extract style (N, idim). spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). threshold (float, optional): Threshold in inference. minlenratio (float, optional): Minimum length ratio in inference. maxlenratio (float, optional): Maximum length ratio in inference. use_att_constraint (bool, optional): Whether to apply attention constraint. backward_window (int, optional): Backward window in attention constraint. forward_window (int, optional): Forward window in attention constraint. use_teacher_forcing (bool, optional): Whether to use teacher forcing. Returns: Tensor: Output sequence of features (L, odim). Tensor: Output sequence of stop probabilities (L,). 
Tensor: Attention weights (L, T). """ record = {23, 33, 70, 72, 76} def get(x, words=3, lookahead=1): l = x.size(0) count = 0 b = [] i = 0 while i < l: if x[i].item() == 1 and i < l - 1 and x[i + 1].item() in record: count += 1 b.append(x[i + 1]) i += 2 elif x[i].item() == 1: count += 1 else: b.append(x[i]) if count == words: if b and (len(b) != 1 or b[0] not in record): b.append(self.eos) chunk_len = len(b) tmp_count = 0 j = i + 1 while j < l: if tmp_count == lookahead: break if x[j].item() == 1 and i < l - 1 and x[ j + 1].item() in record: tmp_count += 1 b.append(x[j + 1]) j += 2 elif x[j].item() == 1: tmp_count += 1 else: b.append(x[j]) j += 1 yield chunk_len, torch.tensor(b).long().cuda() b.clear() count = 0 i += 1 if b and (len(b) != 1 or b[0] not in record): b.append(self.eos) chunk_len = len(b) yield chunk_len, torch.tensor(b).long().cuda() x = text y = speech spemb = spembs # add eos at the last of sequence # x = F.pad(x, [0, 1], "constant", self.eos) # inference with teacher forcing if use_teacher_forcing: assert speech is not None, "speech must be provided with teacher forcing." 
xs, ys = x.unsqueeze(0), y.unsqueeze(0) spembs = None if spemb is None else spemb.unsqueeze(0) ilens = x.new_tensor([xs.size(1)]).long() olens = y.new_tensor([ys.size(1)]).long() outs, _, _, att_ws = self._forward(xs, ilens, ys, olens, spembs) return outs[0], None, att_ws[0] # inference words = 3 lookahead = 1 diff = 0 buffer = 100 chunk = torch.tensor([]).long().cuda() for cur_len, cur in get(x, words, lookahead): add = cur.size(0) - diff diff = cur.size(0) - cur_len chunk = torch.cat([chunk, cur]) if chunk.size(0) > buffer: chunk = chunk[-buffer:] h = self.enc.inference(chunk) target_len = chunk.size(0) - diff outs, probs, att_ws = self.dec.inference( h, threshold=threshold, minlenratio=minlenratio, maxlenratio=maxlenratio, use_att_constraint=use_att_constraint, backward_window=backward_window, forward_window=forward_window, buf=add, target_len=target_len, ) chunk = chunk[:target_len] return outs, probs, att_ws def _integrate_with_spk_embed(self, hs: torch.Tensor, spembs: torch.Tensor) -> torch.Tensor: """Integrate speaker embedding with hidden states. Args: hs (Tensor): Batch of hidden state sequences (B, Tmax, eunits). spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Batch of integrated hidden state sequences (B, Tmax, eunits) if integration_type is "add" else (B, Tmax, eunits + spk_embed_dim). """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spembs = self.projection(F.normalize(spembs)) hs = hs + spembs.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds spembs = F.normalize(spembs).unsqueeze(1).expand( -1, hs.size(1), -1) hs = torch.cat([hs, spembs], dim=-1) else: raise NotImplementedError("support only add or concat.") return hs
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 module for end-to-end text-to-speech (E2E-TTS).

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts the sequence of characters into the sequence of Mel-filterbanks.

    Besides speaker embeddings, this variant can optionally condition on
    character embeddings (fed to the encoder, the decoder, or both) and on an
    intonation-type embedding, and can add an intonation-type prediction loss.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Number of dimension of embedding",
        )
        group.add_argument(
            "--elayers", default=1, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # attention
        group.add_argument(
            "--atype",
            default="location",
            type=str,
            choices=["forward_ta", "forward", "location"],
            help="Type of attention mechanism",
        )
        group.add_argument(
            "--adim",
            default=512,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aconv-chans",
            default=32,
            type=int,
            help="Number of attention convolution channels",
        )
        group.add_argument(
            "--aconv-filts",
            default=15,
            type=int,
            help="Filter size of attention convolution",
        )
        group.add_argument(
            "--cumulate-att-w",
            default=True,
            type=strtobool,
            help="Whether or not to cumulate attention weights",
        )
        # decoder
        group.add_argument(
            "--dlayers", default=2, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1024, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--prenet-layers", default=2, type=int, help="Number of prenet layers"
        )
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument(
            "--postnet-layers", default=5, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=512, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # cbhg
        group.add_argument(
            "--use-cbhg",
            default=False,
            type=strtobool,
            help="Whether to use CBHG module",
        )
        group.add_argument(
            "--cbhg-conv-bank-layers",
            default=8,
            type=int,
            help="Number of convoluional bank layers in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-bank-chans",
            default=128,
            type=int,
            help="Number of convoluional bank channles in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-filts",
            default=3,
            type=int,
            help="Filter size of convoluional projection layer in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-chans",
            default=256,
            type=int,
            help="Number of convoluional projection channels in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-layers",
            default=4,
            type=int,
            help="Number of highway layers in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-units",
            default=128,
            type=int,
            help="Number of highway units in CBHG",
        )
        group.add_argument(
            "--cbhg-gru-units",
            default=256,
            type=int,
            help="Number of GRU units in CBHG",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help="Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument(
            "--dropout-rate", default=0.5, type=float, help="Dropout rate"
        )
        group.add_argument(
            "--zoneout-rate", default=0.1, type=float, help="Zoneout rate"
        )
        group.add_argument(
            "--reduction-factor", default=1, type=int, help="Reduction factor"
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--char-embed-dim",
            default=None,
            type=int,
            help="Number of character embedding dimensions",
        )
        group.add_argument(
            "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions"
        )
        group.add_argument(
            "--pretrained-model", default=None, type=str, help="Pretrained model path"
        )
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        group.add_argument(
            "--bce-pos-weight",
            default=20.0,
            type=float,
            help="Positive sample weight in BCE calculation "
            "(only for use-masking=True)",
        )
        group.add_argument(
            "--use-guided-attn-loss",
            default=False,
            type=strtobool,
            help="Whether to use guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-sigma",
            default=0.4,
            type=float,
            help="Sigma in guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in guided attention loss",
        )
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The number of attention conv filter size.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding with
                    decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of convolutional
                    bank in CBHG.
                - cbhg_proj_filts (int): The number of filter size of projection layer
                    in CBHG.
                - cbhg_proj_chans (int): The number of channels of projection layer
                    in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway network
                    in CBHG.
                - cbhg_highway_units (int): The number of units of highway network
                    in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to apply masking for padded part in
                    loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted masking in
                    loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token
                    (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.
                - char_embed_dim (int, optional): Dimension of the input character
                    embeddings; None disables character-embedding conditioning.
                - character_embedding_position (str): Where character embeddings are
                    injected: "encoder", "decoder" or "both"
                    (required when char_embed_dim is set).
                - character_encoder_type (str, optional): "transformer" selects
                    SentenceEncoder; anything else (or missing) selects
                    CharacterEncoder.
                - reduce_character_embedding (bool): Passed to the character encoder.
                - into_embed_dim (int, optional): Dimension of the intonation-type
                    embedding; None (or missing) disables it.
                - into_type_num (int): Number of intonation types.
                - use_intotype_loss (bool, optional): Whether to add the
                    intonation-type prediction loss (missing means False).

        Raises:
            ValueError: If output_activation is not a ``torch.nn.functional``
                function, or character_embedding_position is invalid.
            NotImplementedError: If atype is not a supported attention type.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # NOTE: these extension options are not registered in add_arguments(),
        # so fill_missing_args() cannot default them. They are read
        # unconditionally below, so fall back to "feature disabled" to keep
        # plain Tacotron2 configs (and args=None) constructible.
        into_embed_dim = getattr(args, "into_embed_dim", None)
        use_intotype_loss = getattr(args, "use_intotype_loss", False)
        character_encoder_type = getattr(args, "character_encoder_type", None)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.char_embed_dim = args.char_embed_dim
        self.into_embed_dim = into_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss
        self.use_intotype_loss = use_intotype_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)"
                % args.output_activation
            )

        # set padding idx
        padding_idx = 0

        # choose the character encoder implementation and its output width
        chenc_type = CharacterEncoder
        chenc_odim = args.eunits
        if character_encoder_type == "transformer":
            chenc_type = SentenceEncoder
            chenc_odim = 256

        # define network modules
        enc_extra_dim = 0
        if args.char_embed_dim is not None and args.character_embedding_position in [
            "encoder",
            "both",
        ]:
            # the pre-encoder output is fed to the encoder as extra features, so
            # its width is the character encoder's output width (as for dec_idim)
            enc_extra_dim = chenc_odim
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            extra_dim=enc_extra_dim,
        )

        def make_chenc(pred_into_type, into_type_num, reduce_character_embedding):
            # build one character encoder; only these three options vary by role
            return chenc_type(
                idim=args.char_embed_dim,
                pred_into_type=pred_into_type,
                into_type_num=into_type_num,
                reduce_character_embedding=reduce_character_embedding,
                elayers=args.elayers,
                eunits=args.eunits,
            )

        self.pre_enc = None
        self.ch_enc = None
        if args.char_embed_dim is not None:
            position = args.character_embedding_position
            if position == "encoder":
                self.pre_enc = make_chenc(
                    use_intotype_loss,
                    args.into_type_num,
                    args.reduce_character_embedding,
                )
            elif position == "decoder":
                self.ch_enc = make_chenc(
                    use_intotype_loss,
                    args.into_type_num,
                    args.reduce_character_embedding,
                )
            elif position == "both":
                self.pre_enc = make_chenc(
                    use_intotype_loss,
                    args.into_type_num,
                    args.reduce_character_embedding,
                )
                # the decoder-side copy never predicts intonation types
                self.ch_enc = make_chenc(False, 0, False)
            else:
                raise ValueError(
                    "Invalid character embedding position \"%s\"" % position
                )

        if into_embed_dim is not None:
            # NOTE: the module is registered under two attribute names
            # ("into_embed" and "embed"); the alias is kept as-is so that
            # existing checkpoints keep loading.
            self.into_embed = self.embed = torch.nn.Embedding(
                args.into_type_num,
                into_embed_dim,
                padding_idx=padding_idx,
            )

        # decoder input = encoder output + every concatenated conditioning feature
        dec_idim = args.eunits
        if args.spk_embed_dim:
            dec_idim += args.spk_embed_dim
        if self.ch_enc is not None:
            dec_idim += chenc_odim
        if into_embed_dim:
            dec_idim += into_embed_dim

        if args.atype == "location":
            att = AttLoc(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward":
            att = AttForward(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_intotype_loss:
            self.intotype_loss = IntoTypeLoss(args.into_type_num)
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

    def expand_to(self, xs, lens):
        """Broadcast one vector per utterance over that utterance's time axis.

        Args:
            xs (Tensor): Batch of vectors (B, D).
            lens (LongTensor or list): Target lengths (B,).

        Returns:
            Tensor: Batch of repeated vectors (B, max(lens), D); positions
                selected by ``make_pad_mask(lens)`` are filled with 0.

        """
        # (B, T, 1) mask built by make_pad_mask, moved to xs' device
        mask = to_device(xs, make_pad_mask(lens).unsqueeze(-1))
        # (B, D) -> (B, 1, D) -> (B, T, D)
        return xs.unsqueeze(1).expand(-1, mask.size(1), -1).masked_fill(mask, 0.0)

    def _encode(self, xs, ilens, chembs=None, chlens=None, intotypes=None, spembs=None):
        """Encode a padded batch and append every configured conditioning feature.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor or list): Batch of input lengths (B,).
            chembs (Tensor, optional): Batch of character embeddings.
            chlens (LongTensor, optional): Lengths of ``chembs``.
            intotypes (LongTensor, optional): Batch of intonation-type ids (B,).
            spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Conditioned hidden states (B, Tmax, dec_idim).
            LongTensor: Lengths of the hidden states (B,).
            Tensor or None: Intonation-type logits of the pre-encoder.
            Tensor or None: Intonation-type logits of the decoder-side encoder.

        """
        pre_xs, pre_type_logits, ch_type_logits = None, None, None
        if self.pre_enc is not None:
            pre_xs, _, pre_type_logits = self.pre_enc(chembs, chlens)
            # The encoder output has one more axis than the ids (B, Tmax);
            # a lower-rank (sentence-level) output is broadcast over time.
            if pre_xs.dim() != xs.dim() + 1:
                pre_xs = self.expand_to(pre_xs, ilens)
        hs, hlens = self.enc(xs, ilens, pre_xs)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        if self.ch_enc is not None:
            ch_hs, _, ch_type_logits = self.ch_enc(chembs, chlens)
            if ch_hs.dim() != hs.dim():
                ch_hs = self.expand_to(ch_hs, ilens)
            hs = torch.cat([hs, ch_hs], dim=-1)
        if self.into_embed_dim is not None:
            itembs = self.into_embed(intotypes).unsqueeze(1).expand(
                -1, hs.size(1), -1
            )
            hs = torch.cat([hs, itembs], dim=-1)
        return hs, hlens, pre_type_logits, ch_type_logits

    def forward(
        self,
        xs,
        ilens,
        ys,
        labels,
        olens,
        chembs=None,
        chlens=None,
        intotypes=None,
        spembs=None,
        extras=None,
        *args,
        **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Batch of stop-token targets (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            chembs (Tensor, optional): Batch of character embeddings.
            chlens (LongTensor, optional): Lengths of ``chembs``.
            intotypes (LongTensor, optional): Batch of intonation-type ids (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim).
            extras (Tensor, optional): Batch of groundtruth spectrograms
                (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens, pre_type_logits, ch_type_logits = self._encode(
            xs, ilens, chembs, chlens, intotypes, spembs
        )
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"bce_loss": bce_loss.item()},
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi):
            # length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [{"attn_loss": attn_loss.item()}]

        # calculate intonation-type loss from whichever encoder predicts types
        if self.use_intotype_loss:
            type_logits = (
                pre_type_logits if self.pre_enc is not None else ch_type_logits
            )
            it_loss = self.intotype_loss(type_logits, intotypes)
            loss = loss + it_loss
            report_keys += [{"intonation_type_loss": it_loss.item()}]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != extras.shape[1]:
                extras = extras[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, extras, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {"cbhg_l1_loss": cbhg_l1_loss.item()},
                {"cbhg_mse_loss": cbhg_mse_loss.item()},
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss

    @staticmethod
    def _print_type_logit(type_logit):
        """Print predicted intonation-type logits and their argmax, if any."""
        if type_logit is not None:
            type_logit = type_logit.data.cpu().numpy()
            print(type_logit)
            print(type_logit.argmax())

    def inference(
        self, x, inference_args, chemb=None, intotype=None, spemb=None, *args, **kwargs
    ):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
                - use_att_constraint (bool, optional): Whether to apply attention
                    constraint (missing means False).
                - backward_window / forward_window (int): Attention constraint
                    windows (only read when use_att_constraint is True).
            chemb (Tensor, optional): Character embeddings of the utterance.
            intotype (LongTensor, optional): Intonation-type id.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio
        use_att_constraint = getattr(
            inference_args, "use_att_constraint", False
        )  # keep compatibility
        backward_window = inference_args.backward_window if use_att_constraint else 0
        forward_window = inference_args.forward_window if use_att_constraint else 0

        # target length used when broadcasting sentence-level vectors over time
        x_len = torch.tensor([x.size(0)])

        # inference
        pre_x = None
        if self.pre_enc is not None:
            pre_x, pre_type_logit = self.pre_enc.inference(chemb)
            # The encoder output has one more axis than the ids (T,);
            # a lower-rank (sentence-level) output is broadcast over time.
            if pre_x.dim() != x.dim() + 1:
                pre_x = self.expand_to(pre_x.unsqueeze(0), x_len).squeeze(0)
            # print prediction of intonation types
            self._print_type_logit(pre_type_logit)
        h = self.enc.inference(x, pre_x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        if self.ch_enc is not None:
            ch_h, ch_type_logit = self.ch_enc.inference(chemb)
            if ch_h.dim() != h.dim():
                ch_h = self.expand_to(ch_h.unsqueeze(0), x_len).squeeze(0)
            # print prediction of intonation types
            self._print_type_logit(ch_type_logit)
            h = torch.cat([h, ch_h], dim=-1)
        if self.into_embed_dim is not None:
            itemb = self.into_embed(intotype).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, itemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(
            h,
            threshold,
            minlenratio,
            maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window,
        )

        if self.use_cbhg:
            outs = self.cbhg.inference(outs)
        return outs, probs, att_ws

    def calculate_all_attentions(
        self,
        xs,
        ilens,
        ys,
        chembs=None,
        chlens=None,
        intotypes=None,
        spembs=None,
        keep_tensor=False,
        *args,
        **kwargs
    ):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            chembs (Tensor, optional): Batch of character embeddings.
            chlens (LongTensor, optional): Lengths of ``chembs``.
            intotypes (LongTensor, optional): Batch of intonation-type ids (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim).
            keep_tensor (bool, optional): Whether to keep original tensor.

        Returns:
            Union[ndarray, Tensor]: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, (torch.Tensor, np.ndarray)):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            hs, hlens, _, _ = self._encode(xs, ilens, chembs, chlens, intotypes, spembs)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        if keep_tensor:
            return att_ws
        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"]
        if self.use_guided_attn_loss:
            plot_keys += ["attn_loss"]
        if self.use_intotype_loss:
            # reported by forward() when the intonation-type loss is enabled
            plot_keys += ["intonation_type_loss"]
        if self.use_cbhg:
            plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"]
        return plot_keys

    def gta_inference(
        self,
        xs,
        ilens,
        ys,
        labels,
        olens,
        chembs=None,
        intotypes=None,
        spembs=None,
        extras=None,
        *args,
        chlens=None,
        **kwargs
    ):
        """Run teacher-forced (ground-truth-aligned) decoding for a batch.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of input lengths (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Stop-token targets (unused, kept for interface).
            olens (LongTensor): Batch of target lengths (B,).
            chembs (Tensor, optional): Batch of character embeddings.
            intotypes (LongTensor, optional): Batch of intonation-type ids (B,).
            spembs (Tensor, optional): Batch of speaker embeddings.
            extras (Tensor, optional): Unused, kept for interface.
            chlens (LongTensor, optional): Lengths of ``chembs``; falls back to
                ``ilens`` when omitted (the previous behavior).

        Returns:
            Tensor: Post-net outputs of the decoder (B, Lmax, odim).

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
        if chlens is None:
            chlens = ilens

        # calculate tacotron2 outputs
        hs, hlens, _, _ = self._encode(xs, ilens, chembs, chlens, intotypes, spembs)
        after_outs, _, _, _ = self.dec(hs, hlens, ys)
        return after_outs
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 based Seq2Seq converts chars to features.

    Reference:
       Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
       (https://arxiv.org/abs/1712.05884)

    :param int idim: dimension of the inputs
    :param int odim: dimension of the outputs
    :param Namespace args: arguments containing following attributes
        (int) spk_embed_dim: dimension of the speaker embedding (None to disable)
        (int) embed_dim: dimension of character embedding
        (int) elayers: the number of encoder blstm layers
        (int) eunits: the number of encoder blstm units
        (int) econv_layers: the number of encoder conv layers
        (int) econv_filts: the number of encoder conv filter size
        (int) econv_chans: the number of encoder conv filter channels
        (str) atype: type of attention ("location", "forward" or "forward_ta")
        (int) dlayers: the number of decoder lstm layers
        (int) dunits: the number of decoder lstm units
        (int) prenet_layers: the number of prenet layers
        (int) prenet_units: the number of prenet units
        (int) postnet_layers: the number of postnet layers
        (int) postnet_filts: the number of postnet filter size
        (int) postnet_chans: the number of postnet filter channels
        (str) output_activation: the name of activation function for outputs
        (int) adim: the number of dimension of mlp in attention
        (int) aconv_chans: the number of attention conv filter channels
        (int) aconv_filts: the number of attention conv filter size
        (bool) cumulate_att_w: whether to cumulate previous attention weight
        (bool) use_batch_norm: whether to use batch normalization
        (bool) use_concate: whether to concatenate encoder embedding with decoder lstm outputs
        (float) dropout_rate: dropout rate
        (float) zoneout_rate: zoneout rate
        (int) reduction_factor: reduction factor
        (bool) use_cbhg: whether to use CBHG module
        (int) spc_dim: output dimension of the CBHG module (only for use_cbhg=True)
        (int) cbhg_conv_bank_layers: the number of convolutional banks in CBHG
        (int) cbhg_conv_bank_chans: the number of channels of convolutional bank in CBHG
        (int) cbhg_conv_proj_filts: the number of filter size of projection layer in CBHG
        (int) cbhg_conv_proj_chans: the number of channels of projection layer in CBHG
        (int) cbhg_highway_layers: the number of layers of highway network in CBHG
        (int) cbhg_highway_units: the number of units of highway network in CBHG
        (int) cbhg_gru_units: the number of units of GRU in CBHG
        (bool) use_masking: whether to mask padded part in loss calculation
        (float) bce_pos_weight: weight of positive sample of stop token (only for use_masking=True)
        (bool) use_guided_attn_loss: whether to use guided attention loss (optional, default False)
        (float) guided_attn_loss_sigma: sigma in guided attention loss
    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        # encoder
        parser.add_argument('--embed-dim', default=512, type=int,
                            help='Number of dimension of embedding')
        parser.add_argument('--elayers', default=1, type=int,
                            help='Number of encoder layers')
        parser.add_argument('--eunits', '-u', default=512, type=int,
                            help='Number of encoder hidden units')
        parser.add_argument('--econv-layers', default=3, type=int,
                            help='Number of encoder convolution layers')
        parser.add_argument('--econv-chans', default=512, type=int,
                            help='Number of encoder convolution channels')
        parser.add_argument('--econv-filts', default=5, type=int,
                            help='Filter size of encoder convolution')
        # attention
        parser.add_argument('--atype', default="location", type=str,
                            choices=["forward_ta", "forward", "location"],
                            help='Type of attention mechanism')
        parser.add_argument('--adim', default=512, type=int,
                            help='Number of attention transformation dimensions')
        parser.add_argument('--aconv-chans', default=32, type=int,
                            help='Number of attention convolution channels')
        parser.add_argument('--aconv-filts', default=15, type=int,
                            help='Filter size of attention convolution')
        parser.add_argument('--cumulate-att-w', default=True, type=strtobool,
                            help="Whether or not to cumulate attention weights")
        # decoder
        parser.add_argument('--dlayers', default=2, type=int,
                            help='Number of decoder layers')
        parser.add_argument('--dunits', default=1024, type=int,
                            help='Number of decoder hidden units')
        parser.add_argument('--prenet-layers', default=2, type=int,
                            help='Number of prenet layers')
        parser.add_argument('--prenet-units', default=256, type=int,
                            help='Number of prenet hidden units')
        parser.add_argument('--postnet-layers', default=5, type=int,
                            help='Number of postnet layers')
        parser.add_argument('--postnet-chans', default=512, type=int,
                            help='Number of postnet channels')
        parser.add_argument('--postnet-filts', default=5, type=int,
                            help='Filter size of postnet')
        parser.add_argument('--output-activation', default=None, type=str, nargs='?',
                            help='Output activation function')
        # cbhg
        parser.add_argument('--use-cbhg', default=False, type=strtobool,
                            help='Whether to use CBHG module')
        parser.add_argument('--cbhg-conv-bank-layers', default=8, type=int,
                            help='Number of convoluional bank layers in CBHG')
        parser.add_argument('--cbhg-conv-bank-chans', default=128, type=int,
                            help='Number of convoluional bank channles in CBHG')
        parser.add_argument('--cbhg-conv-proj-filts', default=3, type=int,
                            help='Filter size of convoluional projection layer in CBHG')
        parser.add_argument('--cbhg-conv-proj-chans', default=256, type=int,
                            help='Number of convoluional projection channels in CBHG')
        parser.add_argument('--cbhg-highway-layers', default=4, type=int,
                            help='Number of highway layers in CBHG')
        parser.add_argument('--cbhg-highway-units', default=128, type=int,
                            help='Number of highway units in CBHG')
        parser.add_argument('--cbhg-gru-units', default=256, type=int,
                            help='Number of GRU units in CBHG')
        # model (parameter) related
        parser.add_argument('--use-batch-norm', default=True, type=strtobool,
                            help='Whether to use batch normalization')
        parser.add_argument('--use-concate', default=True, type=strtobool,
                            help='Whether to concatenate encoder embedding with decoder outputs')
        parser.add_argument('--use-residual', default=True, type=strtobool,
                            help='Whether to use residual connection in conv layer')
        parser.add_argument('--dropout-rate', default=0.5, type=float,
                            help='Dropout rate')
        parser.add_argument('--zoneout-rate', default=0.1, type=float,
                            help='Zoneout rate')
        parser.add_argument('--reduction-factor', default=1, type=int,
                            help='Reduction factor')
        # loss related
        parser.add_argument('--use-masking', default=False, type=strtobool,
                            help='Whether to use masking in calculation of loss')
        parser.add_argument('--bce-pos-weight', default=20.0, type=float,
                            help='Positive sample weight in BCE calculation (only for use-masking=True)')
        parser.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
                            help="Whether to use guided attention loss")
        parser.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
                            help="Sigma in guided attention loss")
        return

    def __init__(self, idim, odim, args):
        """Build encoder, attention, decoder, losses and (optionally) CBHG from ``args``.

        :param int idim: dimension of the inputs (vocabulary size)
        :param int odim: dimension of the outputs
        :param Namespace args: hyperparameters (see class docstring)
        :raises ValueError: if ``args.output_activation`` is not a function in
            ``torch.nn.functional``
        :raises NotImplementedError: if ``args.atype`` is not a supported attention type
        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        # getattr keeps configs written before this option existed loadable
        self.use_guided_attn_loss = getattr(args, "use_guided_attn_loss", False)

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such an activation function. (%s)' % args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        # speaker embeddings are concatenated to the encoder outputs
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim,
                         args.dunits,
                         args.adim,
                         args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim,
                             args.dunits,
                             args.adim,
                             args.aconv_chans,
                             args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled in forward attention.")
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim,
                               args.dunits,
                               args.adim,
                               args.aconv_chans,
                               args.aconv_filts,
                               odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(args)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(sigma=args.guided_attn_loss_sigma)
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(args)

    def forward(self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs):
        """Tacotron2 forward computation.

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor labels: batch of stop token labels (B, Lmax)
        :param torch.Tensor olens: batch of the lengths of each target (B)
        :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
        :param torch.Tensor spcs: batch of groundtruth spectrogram (B, Lmax, spc_dim)
        :return: loss value
        :rtype: torch.Tensor
        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(self.reduction_factor).all(), \
                "Output length must be greater than or equal to reduction factor."
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            # The stop frame must be the last *valid* frame of each utterance,
            # not the last (possibly padded) column of the batch. scatter also
            # returns a new tensor instead of writing into the caller's labels.
            labels = torch.scatter(labels, 1, (olens - 1).unsqueeze(1), 1.0)

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {'l1_loss': l1_loss.item()},
            {'mse_loss': mse_loss.item()},
            {'bce_loss': bce_loss.item()},
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # the decoder takes one step per `reduction_factor` frames, so the
            # attention weights only span olens // r output steps
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [{'attn_loss': attn_loss.item()}]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {'cbhg_l1_loss': cbhg_l1_loss.item()},
                {'cbhg_mse_loss': cbhg_mse_loss.item()},
            ]

        report_keys += [{'loss': loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequence of characters.

        :param torch.Tensor x: the sequence of characters (T)
        :param Namespace inference_args: arguments containing following attributes
            (float) threshold: threshold in inference
            (float) minlenratio: minimum length ratio in inference
            (float) maxlenratio: maximum length ratio in inference
        :param torch.Tensor spemb: speaker embedding vector (spk_embed_dim)
        :return: the sequence of features (L, odim), or the CBHG outputs when
            use_cbhg is enabled
        :rtype: torch.Tensor
        :return: the sequence of stop probabilities (L)
        :rtype: torch.Tensor
        :return: the sequence of attention weight (L, T)
        :rtype: torch.Tensor
        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        return outs, probs, att_ws

    def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs):
        """Tacotron2 attention weight computation.

        Note: the module is switched to eval mode for the computation and put
        back into train mode afterwards.

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
        :return: attention weights (B, Lmax, Tmax)
        :rtype: numpy array
        """
        # check ilens type (should be list of int)
        if isinstance(ilens, (torch.Tensor, np.ndarray)):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            hs, hlens = self.enc(xs, ilens)
            if self.spk_embed_dim is not None:
                spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                hs = torch.cat([hs, spembs], dim=-1)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Base key names to plot during training.

        Keys should match what `chainer.reporter` reports. If you add the key
        `loss`, the reporter will report `main/loss` and `validation/main/loss`
        values. Also `loss.png` will be created as a figure visualizing
        `main/loss` and `validation/main/loss` values.

        :rtype list[str] plot_keys: base keys to plot during training
        """
        plot_keys = ['loss', 'l1_loss', 'mse_loss', 'bce_loss']
        if self.use_guided_attn_loss:
            plot_keys += ['attn_loss']
        if self.use_cbhg:
            plot_keys += ['cbhg_l1_loss', 'cbhg_mse_loss']
        return plot_keys
class Tacotron2(AbsTTS):
    """Tacotron2 module for end-to-end text-to-speech.

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts the sequence of characters into the sequence of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        embed_dim: int = 512,
        elayers: int = 1,
        eunits: int = 512,
        econv_layers: int = 3,
        econv_chans: int = 512,
        econv_filts: int = 5,
        atype: str = "location",
        adim: int = 512,
        aconv_chans: int = 32,
        aconv_filts: int = 15,
        cumulate_att_w: bool = True,
        dlayers: int = 2,
        dunits: int = 1024,
        prenet_layers: int = 2,
        prenet_units: int = 256,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        output_activation: str = None,
        use_batch_norm: bool = True,
        use_concate: bool = True,
        use_residual: bool = False,
        reduction_factor: int = 1,
        # extra embedding related
        spks: int = -1,
        langs: int = -1,
        spk_embed_dim: int = None,
        spk_embed_integration_type: str = "concat",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        dropout_rate: float = 0.5,
        zoneout_rate: float = 0.1,
        use_masking: bool = True,
        use_weighted_masking: bool = False,
        bce_pos_weight: float = 5.0,
        loss_type: str = "L1+L2",
        use_guided_attn_loss: bool = True,
        guided_attn_loss_sigma: float = 0.4,
        guided_attn_loss_lambda: float = 1.0,
    ):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs. The last id (idim - 1) is used as EOS.
            odim (int): Dimension of the outputs.
            embed_dim (int): Dimension of the token embedding.
            elayers (int): Number of encoder blstm layers.
            eunits (int): Number of encoder blstm units.
            econv_layers (int): Number of encoder conv layers.
            econv_filts (int): Number of encoder conv filter size.
            econv_chans (int): Number of encoder conv filter channels.
            atype (str): Attention type ("location", "forward" or "forward_ta").
            dlayers (int): Number of decoder lstm layers.
            dunits (int): Number of decoder lstm units.
            prenet_layers (int): Number of prenet layers.
            prenet_units (int): Number of prenet units.
            postnet_layers (int): Number of postnet layers.
            postnet_filts (int): Number of postnet filter size.
            postnet_chans (int): Number of postnet filter channels.
            output_activation (str): Name of activation function for outputs.
            adim (int): Number of dimension of mlp in attention.
            aconv_chans (int): Number of attention conv filter channels.
            aconv_filts (int): Number of attention conv filter size.
            cumulate_att_w (bool): Whether to cumulate previous attention weight.
            use_batch_norm (bool): Whether to use batch normalization.
            use_concate (bool): Whether to concat enc outputs w/ dec lstm outputs.
            use_residual (bool): Whether to use residual connection in encoder.
            reduction_factor (int): Reduction factor.
            spks (int): Number of speakers. If set to > 0, speaker ID embedding
                will be used.
            langs (int): Number of langs. If set to > 0, lang ID embedding will
                be used.
            spk_embed_dim (int): Pretrained speaker embedding dimension.
            spk_embed_integration_type (str): How to integrate speaker embedding
                ("add" or "concat").
            use_gst (bool): Whether to use global style token.
            gst_tokens (int): Number of GST embeddings.
            gst_heads (int): Number of heads in GST multihead attention.
            gst_conv_layers (int): Number of conv layers in GST.
            gst_conv_chans_list (Sequence[int]): List of the number of channels
                of conv layers in GST.
            gst_conv_kernel_size (int): Kernel size of conv layers in GST.
            gst_conv_stride (int): Stride size of conv layers in GST.
            gst_gru_layers (int): Number of GRU layers in GST.
            gst_gru_units (int): Number of GRU units in GST.
            dropout_rate (float): Dropout rate.
            zoneout_rate (float): Zoneout rate.
            use_masking (bool): Whether to mask padded part in loss calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in
                loss calculation.
            bce_pos_weight (float): Weight of positive sample of stop token
                (only for use_masking=True).
            loss_type (str): Loss function type ("L1", "L2", or "L1+L2").
            use_guided_attn_loss (bool): Whether to use guided attention loss.
            guided_attn_loss_sigma (float): Sigma in guided attention loss.
            guided_attn_loss_lambda (float): Lambda in guided attention loss.

        Raises:
            ValueError: If ``output_activation`` or ``spk_embed_integration_type``
                is not supported.
            NotImplementedError: If ``atype`` is not supported.

        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.eos = idim - 1
        self.spk_embed_dim = spk_embed_dim
        self.cumulate_att_w = cumulate_att_w
        self.reduction_factor = reduction_factor
        self.spks = spks
        self.langs = langs
        self.use_gst = use_gst
        self.use_guided_attn_loss = use_guided_attn_loss
        self.loss_type = loss_type
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # define activation function for the final output
        if output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, output_activation):
            self.output_activation_fn = getattr(F, output_activation)
        else:
            raise ValueError(
                f"there is no such an activation function. " f"({output_activation})"
            )

        # set padding idx
        padding_idx = 0
        self.padding_idx = padding_idx

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=embed_dim,
            elayers=elayers,
            eunits=eunits,
            econv_layers=econv_layers,
            econv_chans=econv_chans,
            econv_filts=econv_filts,
            use_batch_norm=use_batch_norm,
            use_residual=use_residual,
            dropout_rate=dropout_rate,
            padding_idx=padding_idx,
        )

        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=eunits,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # Speaker / language ID embeddings are *added* to the encoder outputs,
        # which have `eunits` channels, so they must be `eunits` wide (sizing
        # them by `embed_dim` broke whenever embed_dim != eunits).
        if self.spks > 0:
            self.sid_emb = torch.nn.Embedding(spks, eunits)
        if self.langs > 0:
            self.lid_emb = torch.nn.Embedding(langs, eunits)

        if spk_embed_dim is None:
            dec_idim = eunits
        elif spk_embed_integration_type == "concat":
            dec_idim = eunits + spk_embed_dim
        elif spk_embed_integration_type == "add":
            dec_idim = eunits
            self.projection = torch.nn.Linear(self.spk_embed_dim, eunits)
        else:
            raise ValueError(f"{spk_embed_integration_type} is not supported.")

        if atype == "location":
            att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
        elif atype == "forward":
            att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled "
                    "in forward attention."
                )
                self.cumulate_att_w = False
        elif atype == "forward_ta":
            att = AttForwardTA(dec_idim, dunits, adim, aconv_chans, aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled "
                    "in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=dlayers,
            dunits=dunits,
            prenet_layers=prenet_layers,
            prenet_units=prenet_units,
            postnet_layers=postnet_layers,
            postnet_chans=postnet_chans,
            postnet_filts=postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=use_batch_norm,
            use_concate=use_concate,
            dropout_rate=dropout_rate,
            zoneout_rate=zoneout_rate,
            reduction_factor=reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        joint_training: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input batch (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            spembs (Optional[Tensor]): Batch of speaker embeddings (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

        """
        text = text[:, : text_lengths.max()]  # for data-parallel
        feats = feats[:, : feats_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = feats
        olens = feats_lengths

        # make labels for stop prediction: 1 from the last valid frame onwards
        labels = make_pad_mask(olens - 1).to(ys.device, ys.dtype)
        labels = F.pad(labels, [0, 1], "constant", 1.0)

        # calculate tacotron2 outputs
        after_outs, before_outs, logits, att_ws = self._forward(
            xs=xs,
            ilens=ilens,
            ys=ys,
            olens=olens,
            spembs=spembs,
            sids=sids,
            lids=lids,
        )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            assert olens.ge(
                self.reduction_factor
            ).all(), "Output length must be greater than or equal to reduction factor."
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(
                labels, 1, (olens - 1).unsqueeze(1), 1.0
            )  # see #3388

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        if self.loss_type == "L1+L2":
            loss = l1_loss + mse_loss + bce_loss
        elif self.loss_type == "L1":
            loss = l1_loss + bce_loss
        elif self.loss_type == "L2":
            loss = mse_loss + bce_loss
        else:
            raise ValueError(f"unknown --loss-type {self.loss_type}")

        stats = dict(
            l1_loss=l1_loss.item(),
            mse_loss=mse_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive
            # input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            stats.update(attn_loss=attn_loss.item())

        if not joint_training:
            stats.update(loss=loss.item())
            loss, stats, weight = force_gatherable(
                (loss, stats, batch_size), loss.device
            )
            return loss, stats, weight
        else:
            return loss, stats, after_outs

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor,
        olens: torch.Tensor,
        spembs: torch.Tensor,
        sids: torch.Tensor,
        lids: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """Encode the inputs, add the optional conditioning and run the decoder.

        Returns:
            The decoder outputs, unpacked by callers as
            (after_outs, before_outs, logits, att_ws).

        """
        hs, hlens = self.enc(xs, ilens)
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)
        if self.spks > 0:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs > 0:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)
        return self.dec(hs, hlens, ys)

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T_text,).
            feats (Optional[Tensor]): Feature sequence used for teacher forcing
                and to extract style with GST (T_feats, odim).
            spembs (Optional[Tensor]): Speaker embedding (spk_embed_dim,).
            sids (Optional[Tensor]): Speaker ID (1,). Required when spks > 0.
            lids (Optional[Tensor]): Language ID (1,). Required when langs > 0.
            threshold (float): Threshold in inference.
            minlenratio (float): Minimum length ratio in inference.
            maxlenratio (float): Maximum length ratio in inference.
            use_att_constraint (bool): Whether to apply attention constraint.
            backward_window (int): Backward window in attention constraint.
            forward_window (int): Forward window in attention constraint.
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Dict[str, Tensor]: Output dict including the following items:
                * feat_gen (Tensor): Output sequence of features (T_feats, odim).
                * prob (Tensor): Output sequence of stop probabilities (T_feats,).
                    Not included when use_teacher_forcing is True.
                * att_w (Tensor): Attention weights (T_feats, T).

        """
        x = text
        y = feats
        spemb = spembs

        # add eos at the last of sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # inference with teacher forcing
        if use_teacher_forcing:
            assert feats is not None, "feats must be provided with teacher forcing."

            xs, ys = x.unsqueeze(0), y.unsqueeze(0)
            spembs = None if spemb is None else spemb.unsqueeze(0)
            ilens = x.new_tensor([xs.size(1)]).long()
            olens = y.new_tensor([ys.size(1)]).long()
            outs, _, _, att_ws = self._forward(
                xs=xs,
                ilens=ilens,
                ys=ys,
                olens=olens,
                spembs=spembs,
                sids=sids,
                lids=lids,
            )

            return dict(feat_gen=outs[0], att_w=att_ws[0])

        # inference; the conditioning below mirrors `_forward` (GST, speaker ID,
        # language ID, speaker embedding) on the unbatched (T, eunits) states
        h = self.enc.inference(x)
        if self.use_gst:
            style_emb = self.gst(y.unsqueeze(0))
            h = h + style_emb
        if self.spks > 0:
            # previously missing: the ID embeddings used in training were
            # silently ignored at inference time
            sid_emb = self.sid_emb(sids.view(-1))
            h = h + sid_emb
        if self.langs > 0:
            lid_emb = self.lid_emb(lids.view(-1))
            h = h + lid_emb
        if self.spk_embed_dim is not None:
            hs, spembs = h.unsqueeze(0), spemb.unsqueeze(0)
            h = self._integrate_with_spk_embed(hs, spembs)[0]
        out, prob, att_w = self.dec.inference(
            h,
            threshold=threshold,
            minlenratio=minlenratio,
            maxlenratio=maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window,
        )

        return dict(feat_gen=out, prob=prob, att_w=att_w)

    def _integrate_with_spk_embed(
        self, hs: torch.Tensor, spembs: torch.Tensor
    ) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, eunits).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, eunits) if
                integration_type is "add" else (B, Tmax, eunits + spk_embed_dim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        else:
            raise NotImplementedError("support only add or concat.")

        return hs
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 module for end-to-end text-to-speech (E2E-TTS).

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram
    Predictions`_, which converts the sequence of characters into the sequence
    of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser.

        Args:
            parser (argparse.ArgumentParser): Parser to which a
                "tacotron 2 model setting" argument group is added.

        Returns:
            argparse.ArgumentParser: The same parser (returned so that
                ``fill_missing_args`` can use this method as a factory).

        """
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument('--embed-dim', default=512, type=int,
                           help='Number of dimension of embedding')
        group.add_argument('--elayers', default=1, type=int,
                           help='Number of encoder layers')
        group.add_argument('--eunits', '-u', default=512, type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--econv-layers', default=3, type=int,
                           help='Number of encoder convolution layers')
        group.add_argument('--econv-chans', default=512, type=int,
                           help='Number of encoder convolution channels')
        group.add_argument('--econv-filts', default=5, type=int,
                           help='Filter size of encoder convolution')
        # attention
        group.add_argument('--atype', default="location", type=str,
                           choices=["forward_ta", "forward", "location"],
                           help='Type of attention mechanism')
        group.add_argument('--adim', default=512, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aconv-chans', default=32, type=int,
                           help='Number of attention convolution channels')
        group.add_argument('--aconv-filts', default=15, type=int,
                           help='Filter size of attention convolution')
        group.add_argument('--cumulate-att-w', default=True, type=strtobool,
                           help="Whether or not to cumulate attention weights")
        # decoder
        group.add_argument('--dlayers', default=2, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=1024, type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--prenet-layers', default=2, type=int,
                           help='Number of prenet layers')
        group.add_argument('--prenet-units', default=256, type=int,
                           help='Number of prenet hidden units')
        group.add_argument('--postnet-layers', default=5, type=int,
                           help='Number of postnet layers')
        group.add_argument('--postnet-chans', default=512, type=int,
                           help='Number of postnet channels')
        group.add_argument('--postnet-filts', default=5, type=int,
                           help='Filter size of postnet')
        group.add_argument('--output-activation', default=None, type=str, nargs='?',
                           help='Output activation function')
        # cbhg
        group.add_argument('--use-cbhg', default=False, type=strtobool,
                           help='Whether to use CBHG module')
        group.add_argument('--cbhg-conv-bank-layers', default=8, type=int,
                           help='Number of convoluional bank layers in CBHG')
        group.add_argument('--cbhg-conv-bank-chans', default=128, type=int,
                           help='Number of convoluional bank channles in CBHG')
        group.add_argument('--cbhg-conv-proj-filts', default=3, type=int,
                           help='Filter size of convoluional projection layer in CBHG')
        group.add_argument('--cbhg-conv-proj-chans', default=256, type=int,
                           help='Number of convoluional projection channels in CBHG')
        group.add_argument('--cbhg-highway-layers', default=4, type=int,
                           help='Number of highway layers in CBHG')
        group.add_argument('--cbhg-highway-units', default=128, type=int,
                           help='Number of highway units in CBHG')
        group.add_argument('--cbhg-gru-units', default=256, type=int,
                           help='Number of GRU units in CBHG')
        # model (parameter) related
        group.add_argument('--use-batch-norm', default=True, type=strtobool,
                           help='Whether to use batch normalization')
        group.add_argument('--use-concate', default=True, type=strtobool,
                           help='Whether to concatenate encoder embedding with decoder outputs')
        group.add_argument('--use-residual', default=True, type=strtobool,
                           help='Whether to use residual connection in conv layer')
        group.add_argument('--dropout-rate', default=0.5, type=float,
                           help='Dropout rate')
        group.add_argument('--zoneout-rate', default=0.1, type=float,
                           help='Zoneout rate')
        group.add_argument('--reduction-factor', default=1, type=int,
                           help='Reduction factor')
        group.add_argument("--spk-embed-dim", default=None, type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spc-dim", default=None, type=int,
                           help="Number of spectrogram dimensions")
        # loss related
        group.add_argument('--use-masking', default=False, type=strtobool,
                           help='Whether to use masking in calculation of loss')
        group.add_argument('--bce-pos-weight', default=20.0, type=float,
                           help='Positive sample weight in BCE calculation (only for use-masking=True)')
        group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
                           help="Whether to use guided attention loss")
        group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
                           help="Sigma in guided attention loss")
        group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
                           help="Lambda in guided attention loss")
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Hyperparameters; any missing entry is
                filled with the default from ``add_arguments``.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - atype (str): Attention type, one of "location", "forward"
                    or "forward_ta".
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The filter size of attention conv layers.
                - cumulate_att_w (bool): Whether to cumulate previous attention
                    weight (forced to False for forward attentions).
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The filter size of postnet layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): Name of a function in
                    ``torch.nn.functional`` applied to the outputs, or None.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - use_residual (bool): Whether to use residual connection in
                    encoder conv layers.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional
                    banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of
                    convolutional bank in CBHG.
                - cbhg_conv_proj_filts (int): The filter size of projection
                    layer in CBHG.
                - cbhg_conv_proj_chans (int): The number of channels of
                    projection layer in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway
                    network in CBHG.
                - cbhg_highway_units (int): The number of units of highway
                    network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss
                    calculation.
                - bce_pos_weight (float): Weight of positive sample of stop
                    token (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        Raises:
            ValueError: If ``output_activation`` is not found in
                ``torch.nn.functional``.
            NotImplementedError: If ``atype`` is not a supported attention type.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such an activation function. (%s)' % args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           use_residual=args.use_residual,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        # speaker embeddings are concatenated to the encoder outputs
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim,
                         args.dunits,
                         args.adim,
                         args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim,
                             args.dunits,
                             args.adim,
                             args.aconv_chans,
                             args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled in forward attention.")
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim,
                               args.dunits,
                               args.adim,
                               args.aconv_chans,
                               args.aconv_filts,
                               odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
                                        bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)

    def forward(self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (Tensor): Batch of stop-token targets (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs (Tensor, optional): Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth: the decoder emits frames in groups of
        # reduction_factor, so trailing frames that do not fill a group are dropped
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {'l1_loss': l1_loss.item()},
            {'mse_loss': mse_loss.item()},
            {'bce_loss': bce_loss.item()},
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {'attn_loss': attn_loss.item()},
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {'cbhg_l1_loss': cbhg_l1_loss.item()},
                {'cbhg_mse_loss': cbhg_mse_loss.item()},
            ]

        report_keys += [{'loss': loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim), or of CBHG outputs
                when ``use_cbhg`` is enabled.
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            # broadcast the single normalized speaker vector over every encoder frame
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws

    def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs):
        """Calculate all of the attention weights.

        The model is temporarily switched to eval mode; its previous
        training/eval mode is restored afterwards, even if an error occurs.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, (torch.Tensor, np.ndarray)):
            ilens = list(map(int, ilens))

        # Remember the current mode: unconditionally calling self.train() here
        # would flip a model that was already in eval mode into training mode.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                hs, hlens = self.enc(xs, ilens)
                if self.spk_embed_dim is not None:
                    spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                    hs = torch.cat([hs, spembs], dim=-1)
                att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        finally:
            self.train(was_training)

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        Keys should match what `chainer.reporter` reports. If you add the key
        `loss`, the reporter will report `main/loss` and `validation/main/loss`
        values. Also `loss.png` will be created as a figure visualizing
        `main/loss` and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ['loss', 'l1_loss', 'mse_loss', 'bce_loss']
        if self.use_guided_attn_loss:
            plot_keys += ['attn_loss']
        if self.use_cbhg:
            plot_keys += ['cbhg_l1_loss', 'cbhg_mse_loss']
        return plot_keys
class Tacotron2(AbsTTS):
    """Tacotron2 module for end-to-end text-to-speech.

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram
    Predictions`_, which converts the sequence of characters into the sequence
    of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    Args:
        idim: Dimension of the inputs (the last index, idim - 1, is used as eos).
        odim: Dimension of the outputs.
        embed_dim: Dimension of character embedding.
        elayers: The number of encoder blstm layers.
        eunits: The number of encoder blstm units.
        econv_layers: The number of encoder conv layers.
        econv_filts: The filter size of encoder conv layers.
        econv_chans: The number of encoder conv filter channels.
        atype: Attention type, one of "location", "forward" or "forward_ta".
        adim: The number of dimension of mlp in attention.
        aconv_chans: The number of attention conv filter channels.
        aconv_filts: The filter size of attention conv layers.
        cumulate_att_w: Whether to cumulate previous attention weight
            (forced to False for forward attentions).
        dlayers: The number of decoder lstm layers.
        dunits: The number of decoder lstm units.
        prenet_layers: The number of prenet layers.
        prenet_units: The number of prenet units.
        postnet_layers: The number of postnet layers.
        postnet_filts: The filter size of postnet layers.
        postnet_chans: The number of postnet filter channels.
        output_activation: Name of a function in ``torch.nn.functional``
            applied to the outputs, or None.
        use_cbhg: Whether to use CBHG module.
        cbhg_conv_bank_layers: The number of convolutional banks in CBHG.
        cbhg_conv_bank_chans: The number of channels of convolutional bank
            in CBHG.
        cbhg_conv_proj_filts: The filter size of projection layer in CBHG.
        cbhg_conv_proj_chans: The number of channels of projection layer
            in CBHG.
        cbhg_highway_layers: The number of layers of highway network in CBHG.
        cbhg_highway_units: The number of units of highway network in CBHG.
        cbhg_gru_units: The number of units of GRU in CBHG.
        use_batch_norm: Whether to use batch normalization.
        use_concate: Whether to concatenate encoder embedding with decoder
            lstm outputs.
        use_residual: Whether to use residual connection in encoder conv
            layers.
        dropout_rate: Dropout rate.
        zoneout_rate: Zoneout rate.
        reduction_factor: Reduction factor.
        spk_embed_dim: Number of speaker embedding dimensions.
        spc_dim: Number of spectrogram dimensions (only for use_cbhg=True).
        use_masking: Whether to mask padded part in loss calculation.
        use_weighted_masking: Whether to apply weighted masking in loss
            calculation.
        bce_pos_weight: Weight of positive sample of stop token
            (only for use_masking=True).
        use_guided_attn_loss: Whether to use guided attention loss.
        guided_attn_loss_sigma: Sigma in guided attention loss.
        guided_attn_loss_lambda: Lambda in guided attention loss.

    """

    def __init__(
        self,
        idim: int,
        odim: int,
        embed_dim: int = 512,
        elayers: int = 1,
        eunits: int = 512,
        econv_layers: int = 3,
        econv_chans: int = 512,
        econv_filts: int = 5,
        atype: str = "location",
        adim: int = 512,
        aconv_chans: int = 32,
        aconv_filts: int = 15,
        cumulate_att_w: bool = True,
        dlayers: int = 2,
        dunits: int = 1024,
        prenet_layers: int = 2,
        prenet_units: int = 256,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        output_activation: str = None,
        use_cbhg: bool = False,
        cbhg_conv_bank_layers: int = 8,
        cbhg_conv_bank_chans: int = 128,
        cbhg_conv_proj_filts: int = 3,
        cbhg_conv_proj_chans: int = 256,
        cbhg_highway_layers: int = 4,
        cbhg_highway_units: int = 128,
        cbhg_gru_units: int = 256,
        use_batch_norm: bool = True,
        use_concate: bool = True,
        use_residual: bool = False,
        dropout_rate: float = 0.5,
        zoneout_rate: float = 0.1,
        reduction_factor: int = 1,
        spk_embed_dim: int = None,
        spc_dim: int = None,
        use_masking: bool = True,
        use_weighted_masking: bool = False,
        bce_pos_weight: float = 5.0,
        use_guided_attn_loss: bool = True,
        guided_attn_loss_sigma: float = 0.4,
        guided_attn_loss_lambda: float = 1.0,
    ):
        """Initialize Tacotron2 module (see the class docstring for arguments).

        Raises:
            ValueError: If ``output_activation`` is not found in
                ``torch.nn.functional``.
            NotImplementedError: If ``atype`` is not a supported attention type.

        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.eos = idim - 1
        self.spk_embed_dim = spk_embed_dim
        self.cumulate_att_w = cumulate_att_w
        self.reduction_factor = reduction_factor
        self.use_cbhg = use_cbhg
        self.use_guided_attn_loss = use_guided_attn_loss

        # define activation function for the final output
        if output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, output_activation):
            self.output_activation_fn = getattr(F, output_activation)
        else:
            raise ValueError(f"there is no such an activation function. "
                             f"({output_activation})")

        # set padding idx
        padding_idx = 0
        self.padding_idx = padding_idx

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=embed_dim,
            elayers=elayers,
            eunits=eunits,
            econv_layers=econv_layers,
            econv_chans=econv_chans,
            econv_filts=econv_filts,
            use_batch_norm=use_batch_norm,
            use_residual=use_residual,
            dropout_rate=dropout_rate,
            padding_idx=padding_idx,
        )
        # speaker embeddings are concatenated to the encoder outputs
        dec_idim = eunits if spk_embed_dim is None else eunits + spk_embed_dim
        if atype == "location":
            att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
        elif atype == "forward":
            att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        elif atype == "forward_ta":
            att = AttForwardTA(dec_idim, dunits, adim, aconv_chans, aconv_filts,
                               odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=dlayers,
            dunits=dunits,
            prenet_layers=prenet_layers,
            prenet_units=prenet_units,
            postnet_layers=postnet_layers,
            postnet_chans=postnet_chans,
            postnet_filts=postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=use_batch_norm,
            use_concate=use_concate,
            dropout_rate=dropout_rate,
            zoneout_rate=zoneout_rate,
            reduction_factor=reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=spc_dim,
                conv_bank_layers=cbhg_conv_bank_layers,
                conv_bank_chans=cbhg_conv_bank_chans,
                conv_proj_filts=cbhg_conv_proj_filts,
                conv_proj_chans=cbhg_conv_proj_chans,
                highway_layers=cbhg_highway_layers,
                highway_units=cbhg_highway_units,
                gru_units=cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=use_masking)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spembs: torch.Tensor = None,
        spcs: torch.Tensor = None,
        spcs_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text: Batch of padded character ids (B, Tmax).
            text_lengths: Batch of lengths of each input batch (B,).
            speech: Batch of padded target features (B, Lmax, odim).
            speech_lengths: Batch of the lengths of each target (B,).
            spembs: Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs: Batch of ground-truth spectrogram (B, Lmax, spc_dim).
                Required when ``use_cbhg`` is enabled.
            spcs_lengths: Batch of the lengths of each spectrogram (B,).
                Currently unused; ``spcs`` is trimmed to the target length.

        Returns:
            Tensor: Loss value.
            Dict[str, float]: Statistics to be monitored.
            Tensor: Weight value (batch size), as prepared by
                ``force_gatherable``.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel
        batch_size = text.size(0)

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 1], "constant", 0.0)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction (1 on padded frames)
        labels = make_pad_mask(olens).to(ys.device, ys.dtype)

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth: the decoder emits frames in groups of
        # reduction_factor, so trailing frames that do not fill a group are dropped
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels, olens)
        loss = l1_loss + mse_loss + bce_loss

        stats = dict(
            l1_loss=l1_loss.item(),
            mse_loss=mse_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive
            # input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            stats.update(attn_loss=attn_loss.item())

        # calculate cbhg loss
        if self.use_cbhg:
            # Remove unnecessary padded part (for multi-gpus). ys has already
            # been trimmed to the (possibly reduction-factor-adjusted) maximum
            # target length, so use its length; previously max_out was only
            # defined when reduction_factor > 1, causing a NameError otherwise.
            max_out = ys.size(1)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(
                cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            stats.update(
                cbhg_l1_loss=cbhg_l1_loss.item(),
                cbhg_mse_loss=cbhg_mse_loss.item(),
            )

        stats.update(loss=loss.item())

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight

    def inference(
        self,
        text: torch.Tensor,
        spembs: torch.Tensor = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text: Input sequence of characters (T,).
            spembs: Speaker embedding vector (spk_embed_dim,).
            threshold: Threshold in inference.
            minlenratio: Minimum length ratio in inference.
            maxlenratio: Maximum length ratio in inference.
            use_att_constraint: Whether to apply attention constraint.
            backward_window: Backward window in attention constraint.
            forward_window: Forward window in attention constraint.

        Returns:
            Tensor: Output sequence of features (L, odim), or of CBHG outputs
                when ``use_cbhg`` is enabled.
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        x = text
        spemb = spembs

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            # broadcast the single normalized speaker vector over every encoder frame
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(
            h,
            threshold=threshold,
            minlenratio=minlenratio,
            maxlenratio=maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window,
        )

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws