Example #1
    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The filter size of postnet layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - adim (int): The dimension of the MLP in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The filter size of attention conv layers.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool):
                    Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int):
                    The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int):
                    The number of channels of convolutional bank in CBHG.
                - cbhg_proj_filts (int):
                    The filter size of the projection layer in CBHG.
                - cbhg_proj_chans (int):
                    The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int):
                    The number of layers of highway network in CBHG.
                - cbhg_highway_units (int):
                    The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token
                    (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.adim = args.adim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.encoder_reduction_factor = args.encoder_reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss
        self.src_reconstruction_loss_lambda = args.src_reconstruction_loss_lambda
        self.trg_reconstruction_loss_lambda = args.trg_reconstruction_loss_lambda

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)" % args.output_activation
            )

        # define network modules
        self.enc = Encoder(
            idim=idim * args.encoder_reduction_factor,
            input_layer="linear",
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
        )
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
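        # e.g. eunits=512 and spk_embed_dim=None give dec_idim=512; with
        # spk_embed_dim=128 the speaker embedding is concatenated, so dec_idim=640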
        if args.atype == "location":
            att = AttLoc(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward":
            att = AttForward(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Only location, forward, and forward_ta attention are supported."
            )
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
        if self.src_reconstruction_loss_lambda > 0:
            self.src_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            self.src_reconstructor_linear = torch.nn.Linear(
                args.econv_chans, idim * args.encoder_reduction_factor
            )

            self.src_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)
        if self.trg_reconstruction_loss_lambda > 0:
            self.trg_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            self.trg_reconstructor_linear = torch.nn.Linear(
                args.econv_chans, odim * args.reduction_factor
            )
            self.trg_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)
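
A minimal construction sketch (an assumption, not code from this file): it builds default
args with the class's own argparser and instantiates the model, assuming the ESPnet
imports used above and the `Tacotron2` class shown in Example #2 are available.

import argparse

parser = argparse.ArgumentParser()
Tacotron2.add_arguments(parser)                   # register the flags listed in Example #2
args = parser.parse_args([])                      # keep every default
model = Tacotron2(idim=80, odim=80, args=args)    # 80-dim acoustic features in and out
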
Example #2
class Tacotron2(TTSInterface, torch.nn.Module):
    """VC Tacotron2 module for VC.

    This is a module of Tacotron2-based VC model,
    which convert the sequence of acoustic features
    into the sequence of acoustic features.
    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--elayers", default=1, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # attention
        group.add_argument(
            "--atype",
            default="location",
            type=str,
            choices=["forward_ta", "forward", "location"],
            help="Type of attention mechanism",
        )
        group.add_argument(
            "--adim",
            default=512,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aconv-chans",
            default=32,
            type=int,
            help="Number of attention convolution channels",
        )
        group.add_argument(
            "--aconv-filts",
            default=15,
            type=int,
            help="Filter size of attention convolution",
        )
        group.add_argument(
            "--cumulate-att-w",
            default=True,
            type=strtobool,
            help="Whether or not to cumulate attention weights",
        )
        # decoder
        group.add_argument(
            "--dlayers", default=2, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1024, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--prenet-layers", default=2, type=int, help="Number of prenet layers"
        )
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument(
            "--postnet-layers", default=5, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=512, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # cbhg
        group.add_argument(
            "--use-cbhg",
            default=False,
            type=strtobool,
            help="Whether to use CBHG module",
        )
        group.add_argument(
            "--cbhg-conv-bank-layers",
            default=8,
            type=int,
            help="Number of convoluional bank layers in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-bank-chans",
            default=128,
            type=int,
            help="Number of convoluional bank channles in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-filts",
            default=3,
            type=int,
            help="Filter size of convoluional projection layer in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-chans",
            default=256,
            type=int,
            help="Number of convoluional projection channels in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-layers",
            default=4,
            type=int,
            help="Number of highway layers in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-units",
            default=128,
            type=int,
            help="Number of highway units in CBHG",
        )
        group.add_argument(
            "--cbhg-gru-units",
            default=256,
            type=int,
            help="Number of GRU units in CBHG",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help="Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument(
            "--dropout-rate", default=0.5, type=float, help="Dropout rate"
        )
        group.add_argument(
            "--zoneout-rate", default=0.1, type=float, help="Zoneout rate"
        )
        group.add_argument(
            "--reduction-factor",
            default=1,
            type=int,
            help="Reduction factor (for decoder)",
        )
        group.add_argument(
            "--encoder-reduction-factor",
            default=1,
            type=int,
            help="Reduction factor (for encoder)",
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions"
        )
        group.add_argument(
            "--pretrained-model", default=None, type=str, help="Pretrained model path"
        )
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--bce-pos-weight",
            default=20.0,
            type=float,
            help="Positive sample weight in BCE calculation "
            "(only for use-masking=True)",
        )
        group.add_argument(
            "--use-guided-attn-loss",
            default=False,
            type=strtobool,
            help="Whether to use guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-sigma",
            default=0.4,
            type=float,
            help="Sigma in guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in guided attention loss",
        )
        group.add_argument(
            "--src-reconstruction-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in source reconstruction loss",
        )
        group.add_argument(
            "--trg-reconstruction-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in target reconstruction loss",
        )
        return parser
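
    # Hedged CLI sketch (illustrative values; the flag names are exactly those registered
    # above): enabling the optional guided-attention and CBHG branches handled in
    # __init__ below could look like
    #   --use-guided-attn-loss true --guided-attn-loss-sigma 0.4 \
    #   --use-cbhg true --spc-dim 1025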

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The filter size of postnet layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - adim (int): The dimension of the MLP in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The filter size of attention conv layers.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool):
                    Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int):
                    The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int):
                    The number of channels of convolutional bank in CBHG.
                - cbhg_proj_filts (int):
                    The filter size of the projection layer in CBHG.
                - cbhg_proj_chans (int):
                    The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int):
                    The number of layers of highway network in CBHG.
                - cbhg_highway_units (int):
                    The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token
                    (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.adim = args.adim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.encoder_reduction_factor = args.encoder_reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss
        self.src_reconstruction_loss_lambda = args.src_reconstruction_loss_lambda
        self.trg_reconstruction_loss_lambda = args.trg_reconstruction_loss_lambda

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)" % args.output_activation
            )

        # define network modules
        self.enc = Encoder(
            idim=idim * args.encoder_reduction_factor,
            input_layer="linear",
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
        )
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        if args.atype == "location":
            att = AttLoc(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward":
            att = AttForward(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Only location, forward, and forward_ta attention are supported."
            )
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
        if self.src_reconstruction_loss_lambda > 0:
            self.src_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            self.src_reconstructor_linear = torch.nn.Linear(
                args.econv_chans, idim * args.encoder_reduction_factor
            )

            self.src_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)
        if self.trg_reconstruction_loss_lambda > 0:
            self.trg_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            self.trg_reconstructor_linear = torch.nn.Linear(
                args.econv_chans, odim * args.reduction_factor
            )
            self.trg_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

    def forward(
        self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded acoustic features (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # thin out input frames for reduction factor
        # (B, Lmax, idim) ->  (B, Lmax // r, idim * r)
        if self.encoder_reduction_factor > 1:
            B, Lmax, idim = xs.shape
            if Lmax % self.encoder_reduction_factor != 0:
                xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :]
            xs_ds = xs.contiguous().view(
                B,
                int(Lmax / self.encoder_reduction_factor),
                idim * self.encoder_reduction_factor,
            )
            ilens_ds = ilens.new(
                [ilen // self.encoder_reduction_factor for ilen in ilens]
            )
        else:
            xs_ds, ilens_ds = xs, ilens
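        # shape sketch: with encoder_reduction_factor=2 and xs of shape (B, 7, 80),
        # the trailing odd frame is dropped and the view above yields xs_ds of shape
        # (B, 3, 160), i.e. pairs of consecutive frames stacked along the feature axis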

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs_ds, ilens_ds)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # calculate src reconstruction
        if self.src_reconstruction_loss_lambda > 0:
            B, _in_length, _adim = hs.shape
            xt, xtlens = self.src_reconstructor(hs, hlens)
            xt = self.src_reconstructor_linear(xt)
            if self.encoder_reduction_factor > 1:
                xt = xt.view(B, -1, self.idim)

        # calculate trg reconstruction
        if self.trg_reconstruction_loss_lambda > 0:
            olens_trg_cp = olens.new(
                sorted([olen // self.reduction_factor for olen in olens], reverse=True)
            )
            B, _in_length, _adim = hs.shape
            _, _out_length, _ = att_ws.shape
            # att_R should be [B, out_length / r_d, adim]
            att_R = torch.sum(
                hs.view(B, 1, _in_length, _adim)
                * att_ws.view(B, _out_length, _in_length, 1),
                dim=2,
            )
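            # equivalently: att_R = torch.bmm(att_ws, hs), i.e.
            #   att_R[b, o, :] = sum_t att_ws[b, o, t] * hs[b, t, :]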
            yt, ytlens = self.trg_reconstructor(
                att_R, olens_trg_cp
            )  # is using olens correct?
            yt = self.trg_reconstructor_linear(yt)
            if self.reduction_factor > 1:
                yt = yt.view(
                    B, -1, self.odim
                )  # now att_R should be [B, out_length, adim]

        # trim the groundtruth so its length is a multiple of the reduction factor
        if self.reduction_factor > 1:
            assert olens.ge(
                self.reduction_factor
            ).all(), "Output length must be greater than or equal to reduction factor."
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels = torch.scatter(
                labels, 1, (olens - 1).unsqueeze(1), 1.0
            )  # see #3388
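            # the scatter above writes 1.0 at index olens - 1 along dim 1, so the last
            # kept frame of every utterance still carries the stop label after trimming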
        if self.encoder_reduction_factor > 1:
            ilens = ilens.new(
                [ilen - ilen % self.encoder_reduction_factor for ilen in ilens]
            )
            max_in = max(ilens)
            xs = xs[:, :max_in]

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens
        )
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"bce_loss": bce_loss.item()},
        ]

        # calculate context_preservation loss
        if self.src_reconstruction_loss_lambda > 0:
            src_recon_l1_loss, src_recon_mse_loss = self.src_reconstruction_loss(
                xt, xs, ilens
            )
            loss = loss + src_recon_l1_loss
            report_keys += [
                {"src_recon_l1_loss": src_recon_l1_loss.item()},
                {"src_recon_mse_loss": src_recon_mse_loss.item()},
            ]
        if self.trg_reconstruction_loss_lambda > 0:
            trg_recon_l1_loss, trg_recon_mse_loss = self.trg_reconstruction_loss(
                yt, ys, olens
            )
            loss = loss + trg_recon_l1_loss
            report_keys += [
                {"trg_recon_l1_loss": trg_recon_l1_loss.item()},
                {"trg_recon_mse_loss": trg_recon_mse_loss.item()},
            ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive input
            #   will be changed when r > 1
            if self.encoder_reduction_factor > 1:
                ilens_in = ilens.new(
                    [ilen // self.encoder_reduction_factor for ilen in ilens]
                )
            else:
                ilens_in = ilens
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens_in, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {"attn_loss": attn_loss.item()},
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {"cbhg_l1_loss": cbhg_l1_loss.item()},
                {"cbhg_mse_loss": cbhg_mse_loss.item()},
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of acoustic features (T, idim).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # thin out input frames for reduction factor
        # (B, Lmax, idim) ->  (B, Lmax // r, idim * r)
        if self.encoder_reduction_factor > 1:
            Lmax, idim = x.shape
            if Lmax % self.encoder_reduction_factor != 0:
                x = x[: -(Lmax % self.encoder_reduction_factor), :]
            x_ds = x.contiguous().view(
                int(Lmax / self.encoder_reduction_factor),
                idim * self.encoder_reduction_factor,
            )
        else:
            x_ds = x

        # inference
        h = self.enc.inference(x_ds)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws
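
        # Hedged usage sketch (the threshold/length-ratio values are illustrative, not
        # taken from this file), assuming a trained `model` and an input feature matrix
        # `x` of shape (T, idim):
        #   from argparse import Namespace
        #   outs, probs, att_ws = model.inference(
        #       x, Namespace(threshold=0.5, minlenratio=0.0, maxlenratio=10.0))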

    def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded acoustic features (B, Tmax, idim).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            # thin out input frames for reduction factor
            # (B, Lmax, idim) ->  (B, Lmax // r, idim * r)
            if self.encoder_reduction_factor > 1:
                B, Lmax, idim = xs.shape
                if Lmax % self.encoder_reduction_factor != 0:
                    xs = xs[:, : -(Lmax % self.encoder_reduction_factor), :]
                xs_ds = xs.contiguous().view(
                    B,
                    int(Lmax / self.encoder_reduction_factor),
                    idim * self.encoder_reduction_factor,
                )
                ilens_ds = [ilen // self.encoder_reduction_factor for ilen in ilens]
            else:
                xs_ds, ilens_ds = xs, ilens

            hs, hlens = self.enc(xs_ds, ilens_ds)
            if self.spk_embed_dim is not None:
                spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                hs = torch.cat([hs, spembs], dim=-1)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        The keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
            and `validation/main/loss` values.
        In addition, `loss.png` will be created as a figure visualizing `main/loss`
            and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"]
        if self.use_guided_attn_loss:
            plot_keys += ["attn_loss"]
        if self.use_cbhg:
            plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"]
        if self.src_reconstruction_loss_lambda > 0:
            plot_keys += ["src_recon_l1_loss", "src_recon_mse_loss"]
        if self.trg_reconstruction_loss_lambda > 0:
            plot_keys += ["trg_recon_l1_loss", "trg_recon_mse_loss"]
        return plot_keys

    def _sort_by_length(self, xs, ilens):
        sort_ilens, sort_idx = ilens.sort(0, descending=True)
        return xs[sort_idx], ilens[sort_idx], sort_idx

    def _revert_sort_by_length(self, xs, ilens, sort_idx):
        _, revert_idx = sort_idx.sort(0)
        return xs[revert_idx], ilens[revert_idx]
Example #3
    def __init__(self, idim, odim, args):
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = getattr(args, "use_guided_attn_loss",
                                            False)

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such activation function. (%s)' %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim, args.dunits, args.adim,
                               args.aconv_chans, args.aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Only location, forward, and forward_ta attention are supported."
            )
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(args)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma)
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(args)
Example #4
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 module for end-to-end text-to-speech (E2E-TTS).

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis
    by Conditioning WaveNet on Mel Spectrogram Predictions`_,
    which converts the sequence of characters
    into the sequence of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """
    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Number of dimension of embedding",
        )
        group.add_argument("--elayers",
                           default=1,
                           type=int,
                           help="Number of encoder layers")
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # attention
        group.add_argument(
            "--atype",
            default="location",
            type=str,
            choices=["forward_ta", "forward", "location"],
            help="Type of attention mechanism",
        )
        group.add_argument(
            "--adim",
            default=512,
            type=int,
            help="Number of attention transformation dimensions",
        )
        group.add_argument(
            "--aconv-chans",
            default=32,
            type=int,
            help="Number of attention convolution channels",
        )
        group.add_argument(
            "--aconv-filts",
            default=15,
            type=int,
            help="Filter size of attention convolution",
        )
        group.add_argument(
            "--cumulate-att-w",
            default=True,
            type=strtobool,
            help="Whether or not to cumulate attention weights",
        )
        # decoder
        group.add_argument("--dlayers",
                           default=2,
                           type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits",
                           default=1024,
                           type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--prenet-layers",
                           default=2,
                           type=int,
                           help="Number of prenet layers")
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument("--postnet-layers",
                           default=5,
                           type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans",
                           default=512,
                           type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts",
                           default=5,
                           type=int,
                           help="Filter size of postnet")
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # cbhg
        group.add_argument(
            "--use-cbhg",
            default=False,
            type=strtobool,
            help="Whether to use CBHG module",
        )
        group.add_argument(
            "--cbhg-conv-bank-layers",
            default=8,
            type=int,
            help="Number of convoluional bank layers in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-bank-chans",
            default=128,
            type=int,
            help="Number of convoluional bank channles in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-filts",
            default=3,
            type=int,
            help="Filter size of convoluional projection layer in CBHG",
        )
        group.add_argument(
            "--cbhg-conv-proj-chans",
            default=256,
            type=int,
            help="Number of convoluional projection channels in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-layers",
            default=4,
            type=int,
            help="Number of highway layers in CBHG",
        )
        group.add_argument(
            "--cbhg-highway-units",
            default=128,
            type=int,
            help="Number of highway units in CBHG",
        )
        group.add_argument(
            "--cbhg-gru-units",
            default=256,
            type=int,
            help="Number of GRU units in CBHG",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help=
            "Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument("--dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate")
        group.add_argument("--zoneout-rate",
                           default=0.1,
                           type=float,
                           help="Zoneout rate")
        group.add_argument("--reduction-factor",
                           default=1,
                           type=int,
                           help="Reduction factor")
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--char-embed-dim",
            default=None,
            type=int,
            help="Number of character embedding dimensions",
        )
        group.add_argument("--spc-dim",
                           default=None,
                           type=int,
                           help="Number of spectrogram dimensions")
        group.add_argument("--pretrained-model",
                           default=None,
                           type=str,
                           help="Pretrained model path")
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        group.add_argument(
            "--bce-pos-weight",
            default=20.0,
            type=float,
            help="Positive sample weight in BCE calculation "
            "(only for use-masking=True)",
        )
        group.add_argument(
            "--use-guided-attn-loss",
            default=False,
            type=strtobool,
            help="Whether to use guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-sigma",
            default=0.4,
            type=float,
            help="Sigma in guided attention loss",
        )
        group.add_argument(
            "--guided-attn-loss-lambda",
            default=1.0,
            type=float,
            help="Lambda in guided attention loss",
        )
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The filter size of postnet layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - adim (int): The dimension of the MLP in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The filter size of attention conv layers.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of
                    convolutional bank in CBHG.
                - cbhg_proj_filts (int):
                    The filter size of the projection layer in CBHG.
                - cbhg_proj_chans (int):
                    The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int):
                    The number of layers of highway network in CBHG.
                - cbhg_highway_units (int):
                    The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool):
                    Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool):
                    Whether to apply weighted masking in loss calculation.
                - bce_pos_weight (float):
                    Weight of positive sample of stop token (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.char_embed_dim = args.char_embed_dim
        self.into_embed_dim = args.into_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss
        self.use_intotype_loss = args.use_intotype_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError("there is no such an activation function. (%s)" %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        enc_extra_dim = 0
        if args.char_embed_dim is not None and args.character_embedding_position in [
                'encoder', 'both'
        ]:
            enc_extra_dim = args.eunits
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            extra_dim=enc_extra_dim,
        )
        self.pre_enc = None
        self.ch_enc = None

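        # pre_enc feeds the character-encoder output into the phoneme encoder
        # ("encoder"/"both" positions); ch_enc output is concatenated to the
        # encoder states that condition the decoder ("decoder"/"both" positions).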
        chenc_type = CharacterEncoder
        chenc_odim = args.eunits
        if args.character_encoder_type == 'transformer':
            chenc_type = SentenceEncoder
            chenc_odim = 256

        if args.char_embed_dim is not None:
            if args.character_embedding_position == 'encoder':
                self.pre_enc = chenc_type(
                    idim=args.char_embed_dim,
                    pred_into_type=args.use_intotype_loss,
                    into_type_num=args.into_type_num,
                    reduce_character_embedding=args.reduce_character_embedding,
                    elayers=args.elayers,
                    eunits=args.eunits,
                )
            elif args.character_embedding_position == 'decoder':
                self.ch_enc = chenc_type(
                    idim=args.char_embed_dim,
                    pred_into_type=args.use_intotype_loss,
                    into_type_num=args.into_type_num,
                    reduce_character_embedding=args.reduce_character_embedding,
                    elayers=args.elayers,
                    eunits=args.eunits,
                )
            elif args.character_embedding_position == 'both':
                self.pre_enc = chenc_type(
                    idim=args.char_embed_dim,
                    pred_into_type=args.use_intotype_loss,
                    into_type_num=args.into_type_num,
                    reduce_character_embedding=args.reduce_character_embedding,
                    elayers=args.elayers,
                    eunits=args.eunits,
                )
                self.ch_enc = chenc_type(
                    idim=args.char_embed_dim,
                    pred_into_type=False,
                    into_type_num=0,
                    reduce_character_embedding=False,
                    elayers=args.elayers,
                    eunits=args.eunits,
                )
            else:
                raise ValueError(
                    "Invalid character embedding position \"%s\"" %
                    args.character_embedding_position)
        if args.into_embed_dim is not None:
            self.into_embed = torch.nn.Embedding(
                args.into_type_num,
                args.into_embed_dim,
                padding_idx=padding_idx,
            )

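        # decoder input dim = encoder units (+ speaker embedding)
        # (+ character-encoder output) (+ intonation-type embedding)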
        dec_idim = args.eunits
        if args.spk_embed_dim:
            dec_idim += args.spk_embed_dim
        if self.ch_enc is not None:
            dec_idim += chenc_odim
        if args.into_embed_dim:
            dec_idim += args.into_embed_dim

        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_intotype_loss:
            self.intotype_loss = IntoTypeLoss(args.into_type_num)

        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

    def expand_to(self, xs, lens):
        """
            xs: (B, D)
            lens: (B,)
        """
        # (B, T, 1)
        mask = to_device(xs, make_pad_mask(lens).unsqueeze(-1))
        # (B, D) -> (B, 1, D) -> (B, T, D)
        xs = xs.unsqueeze(1).expand(-1, mask.size(1),
                                    -1).masked_fill(mask, 0.0)
        return xs

    def forward(self,
                xs,
                ilens,
                ys,
                labels,
                olens,
                chembs=None,
                chlens=None,
                intotypes=None,
                spembs=None,
                extras=None,
                *args,
                **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
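            chembs (Tensor, optional): Batch of padded character-level embedding
                sequences for the character encoder.
            chlens (LongTensor, optional): Batch of lengths of each character
                embedding sequence (B,).
            intotypes (LongTensor, optional): Batch of intonation type ids (B,).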
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        pre_xs = None
        if self.pre_enc is not None:
            pre_xs, _, pre_type_logits = self.pre_enc(chembs, chlens)
            if pre_xs.ndim == 2:
                # utterance-level (B, D) output -> expand to (B, Tmax, D)
                pre_xs = self.expand_to(pre_xs, ilens)

        hs, hlens = self.enc(xs, ilens, pre_xs)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        if self.ch_enc is not None:
            ch_hs, _, ch_type_logits = self.ch_enc(chembs, chlens)
            if ch_hs.ndim != hs.ndim:
                ch_hs = self.expand_to(ch_hs, ilens)
            hs = torch.cat([hs, ch_hs], dim=-1)

        if self.into_embed_dim is not None:
            itembs = self.into_embed(intotypes).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, itembs], dim=-1)

        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {
                "l1_loss": l1_loss.item()
            },
            {
                "mse_loss": mse_loss.item()
            },
            {
                "bce_loss": bce_loss.item()
            },
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi):
            # length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {
                    "attn_loss": attn_loss.item()
                },
            ]

        if self.use_intotype_loss:
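            # take the intonation-type logits from the character encoder that
            # produced them; pre_enc takes precedence when both encoders exist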
            type_logits = pre_type_logits if self.pre_enc is not None else ch_type_logits
            it_loss = self.intotype_loss(type_logits, intotypes)
            loss = loss + it_loss
            report_keys += [{"intonation_type_loss": it_loss.item()}]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != extras.shape[1]:
                extras = extras[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(
                cbhg_outs, extras, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {
                    "cbhg_l1_loss": cbhg_l1_loss.item()
                },
                {
                    "cbhg_mse_loss": cbhg_mse_loss.item()
                },
            ]

        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self,
                  x,
                  inference_args,
                  chemb=None,
                  intotype=None,
                  spemb=None,
                  *args,
                  **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
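            chemb (Tensor, optional): Character-level embedding sequence for the
                character encoder.
            intotype (LongTensor, optional): Intonation type id.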
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio
        use_att_constraint = getattr(inference_args, "use_att_constraint",
                                     False)  # keep compatibility
        backward_window = inference_args.backward_window if use_att_constraint else 0
        forward_window = inference_args.forward_window if use_att_constraint else 0
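        # with the attention constraint enabled, the decoder limits how far the
        # attention peak may move per step using these backward/forward windows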

        # inference
        pre_x = None
        if self.pre_enc is not None:
            pre_x, pre_type_logit = self.pre_enc.inference(chemb)
            if pre_x.ndim != x.ndim:
                pre_x = self.expand_to(pre_x.unsqueeze(0),
                                       torch.tensor([x.size(0)])).squeeze(0)
            # To print prediction of intonation types
            if pre_type_logit is not None:
                pre_type_logit = pre_type_logit.data.cpu().numpy()
                print(pre_type_logit)
                print(pre_type_logit.argmax())
            # ============
        h = self.enc.inference(x, pre_x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb,
                                dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        if self.ch_enc is not None:
            ch_h, ch_type_logit = self.ch_enc.inference(chemb)
            if ch_h.ndim != h.ndim:
                ch_h = self.expand_to(ch_h.unsqueeze(0),
                                      torch.tensor([x.size(0)])).squeeze(0)
            # To print prediction of intonation types
            if ch_type_logit is not None:
                ch_type_logit = ch_type_logit.data.cpu().numpy()
                print(ch_type_logit)
                print(ch_type_logit.argmax())
            # ============
            h = torch.cat([h, ch_h], dim=-1)
        if self.into_embed_dim:
            itemb = self.into_embed(intotype).unsqueeze(0).expand(
                h.size(0), -1)
            h = torch.cat([h, itemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(
            h,
            threshold,
            minlenratio,
            maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window,
        )

        if self.use_cbhg:
            outs = self.cbhg.inference(outs)
        return outs, probs, att_ws

    def calculate_all_attentions(self,
                                 xs,
                                 ilens,
                                 ys,
                                 chembs=None,
                                 chlens=None,
                                 intotypes=None,
                                 spembs=None,
                                 keep_tensor=False,
                                 *args,
                                 **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            chembs (Tensor, optional): Batch of padded character-level embedding
                sequences for the character encoder.
            chlens (LongTensor, optional): Batch of lengths of each character
                embedding sequence (B,).
            intotypes (LongTensor, optional): Batch of intonation type ids (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            keep_tensor (bool, optional): Whether to keep original tensor.

        Returns:
            Union[ndarray, Tensor]: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            pre_xs = None
            if self.pre_enc is not None:
                pre_xs, _, pre_type_logits = self.pre_enc(chembs, chlens)
                if pre_xs.ndim == 2:
                    # utterance-level (B, D) output -> expand to (B, Tmax, D)
                    pre_xs = self.expand_to(pre_xs, ilens)
            hs, hlens = self.enc(xs, ilens, pre_xs)
            if self.spk_embed_dim is not None:
                spembs = F.normalize(spembs).unsqueeze(1).expand(
                    -1, hs.size(1), -1)
                hs = torch.cat([hs, spembs], dim=-1)
            if self.ch_enc is not None:
                ch_hs, _, ch_type_logits = self.ch_enc(chembs, chlens)
                if ch_hs.ndim != hs.ndim:
                    ch_hs = self.expand_to(ch_hs, ilens)
                hs = torch.cat([hs, ch_hs], dim=-1)
            if self.into_embed_dim is not None:
                itembs = self.into_embed(intotypes).unsqueeze(1).expand(
                    -1, hs.size(1), -1)
                hs = torch.cat([hs, itembs], dim=-1)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        if keep_tensor:
            return att_ws
        else:
            return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        Keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        Also, `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "bce_loss"]
        if self.use_guided_attn_loss:
            plot_keys += ["attn_loss"]
        if self.use_cbhg:
            plot_keys += ["cbhg_l1_loss", "cbhg_mse_loss"]
        return plot_keys

    def gta_inference(self,
                      xs,
                      ilens,
                      ys,
                      labels,
                      olens,
                      chembs=None,
                      intotypes=None,
                      spembs=None,
                      extras=None,
                      *args,
                      **kwargs):
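        """Calculate ground-truth-aligned (teacher-forced) outputs.

        Runs the encoder and the decoder with teacher forcing on the target
        features and returns only the postnet outputs, e.g. to extract features
        aligned with the ground truth.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).

        Returns:
            Tensor: Batch of postnet output features (B, Lmax, odim).

        """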
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        pre_xs = None
        if self.pre_enc is not None:
            pre_xs, _, pre_type_logits = self.pre_enc(chembs, ilens)

        hs, hlens = self.enc(xs, ilens, pre_xs)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        if self.ch_enc is not None:
            ch_hs, _, ch_type_logits = self.ch_enc(chembs, ilens)
            hs = torch.cat([hs, ch_hs], dim=-1)

        if self.into_embed_dim is not None:
            itembs = self.into_embed(intotypes).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, itembs], dim=-1)

        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        return after_outs


class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 based Seq2Seq converts chars to features

    Reference:
       Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions
       (https://arxiv.org/abs/1712.05884)

    :param int idim: dimension of the inputs
    :param int odim: dimension of the outputs
    :param Namespace args: arguments containing the following attributes
        (int) spk_embed_dim: dimension of the speaker embedding
        (int) embed_dim: dimension of character embedding
        (int) elayers: the number of encoder blstm layers
        (int) eunits: the number of encoder blstm units
        (int) econv_layers: the number of encoder conv layers
        (int) econv_filts: the number of encoder conv filter size
        (int) econv_chans: the number of encoder conv filter channels
        (int) dlayers: the number of decoder lstm layers
        (int) dunits: the number of decoder lstm units
        (int) prenet_layers: the number of prenet layers
        (int) prenet_units: the number of prenet units
        (int) postnet_layers: the number of postnet layers
        (int) postnet_filts: the number of postnet filter size
        (int) postnet_chans: the number of postnet filter channels
        (str) output_activation: the name of activation function for outputs
        (int) adim: the number of dimension of mlp in attention
        (int) aconv_chans: the number of attention conv filter channels
        (int) aconv_filts: the number of attention conv filter size
        (bool) cumulate_att_w: whether to cumulate previous attention weight
        (bool) use_batch_norm: whether to use batch normalization
        (bool) use_concate: whether to concatenate encoder embedding with decoder lstm outputs
        (float) dropout_rate: dropout rate
        (float) zoneout_rate: zoneout rate
        (int) reduction_factor: reduction factor
        (bool) use_cbhg: whether to use CBHG module
        (int) cbhg_conv_bank_layers: the number of convolutional banks in CBHG
        (int) cbhg_conv_bank_chans: the number of channels of convolutional bank in CBHG
        (int) cbhg_proj_filts: the filter size of the projection layer in CBHG
        (int) cbhg_proj_chans: the number of channels of projection layer in CBHG
        (int) cbhg_highway_layers: the number of layers of highway network in CBHG
        (int) cbhg_highway_units: the number of units of highway network in CBHG
        (int) cbhg_gru_units: the number of units of GRU in CBHG
        (bool) use_masking: whether to mask padded part in loss calculation
        (float) bce_pos_weight: weight of positive sample of stop token (only for use_masking=True)
    """
    @staticmethod
    def add_arguments(parser):
        # encoder
        parser.add_argument('--embed-dim',
                            default=512,
                            type=int,
                            help='Number of dimension of embedding')
        parser.add_argument('--elayers',
                            default=1,
                            type=int,
                            help='Number of encoder layers')
        parser.add_argument('--eunits',
                            '-u',
                            default=512,
                            type=int,
                            help='Number of encoder hidden units')
        parser.add_argument('--econv-layers',
                            default=3,
                            type=int,
                            help='Number of encoder convolution layers')
        parser.add_argument('--econv-chans',
                            default=512,
                            type=int,
                            help='Number of encoder convolution channels')
        parser.add_argument('--econv-filts',
                            default=5,
                            type=int,
                            help='Filter size of encoder convolution')
        # attention
        parser.add_argument('--atype',
                            default="location",
                            type=str,
                            choices=["forward_ta", "forward", "location"],
                            help='Type of attention mechanism')
        parser.add_argument(
            '--adim',
            default=512,
            type=int,
            help='Number of attention transformation dimensions')
        parser.add_argument('--aconv-chans',
                            default=32,
                            type=int,
                            help='Number of attention convolution channels')
        parser.add_argument('--aconv-filts',
                            default=15,
                            type=int,
                            help='Filter size of attention convolution')
        parser.add_argument(
            '--cumulate-att-w',
            default=True,
            type=strtobool,
            help="Whether or not to cumulate attention weights")
        # decoder
        parser.add_argument('--dlayers',
                            default=2,
                            type=int,
                            help='Number of decoder layers')
        parser.add_argument('--dunits',
                            default=1024,
                            type=int,
                            help='Number of decoder hidden units')
        parser.add_argument('--prenet-layers',
                            default=2,
                            type=int,
                            help='Number of prenet layers')
        parser.add_argument('--prenet-units',
                            default=256,
                            type=int,
                            help='Number of prenet hidden units')
        parser.add_argument('--postnet-layers',
                            default=5,
                            type=int,
                            help='Number of postnet layers')
        parser.add_argument('--postnet-chans',
                            default=512,
                            type=int,
                            help='Number of postnet channels')
        parser.add_argument('--postnet-filts',
                            default=5,
                            type=int,
                            help='Filter size of postnet')
        parser.add_argument('--output-activation',
                            default=None,
                            type=str,
                            nargs='?',
                            help='Output activation function')
        # cbhg
        parser.add_argument('--use-cbhg',
                            default=False,
                            type=strtobool,
                            help='Whether to use CBHG module')
        parser.add_argument('--cbhg-conv-bank-layers',
                            default=8,
                            type=int,
                            help='Number of convolutional bank layers in CBHG')
        parser.add_argument(
            '--cbhg-conv-bank-chans',
            default=128,
            type=int,
            help='Number of convolutional bank channels in CBHG')
        parser.add_argument(
            '--cbhg-conv-proj-filts',
            default=3,
            type=int,
            help='Filter size of convolutional projection layer in CBHG')
        parser.add_argument(
            '--cbhg-conv-proj-chans',
            default=256,
            type=int,
            help='Number of convolutional projection channels in CBHG')
        parser.add_argument('--cbhg-highway-layers',
                            default=4,
                            type=int,
                            help='Number of highway layers in CBHG')
        parser.add_argument('--cbhg-highway-units',
                            default=128,
                            type=int,
                            help='Number of highway units in CBHG')
        parser.add_argument('--cbhg-gru-units',
                            default=256,
                            type=int,
                            help='Number of GRU units in CBHG')
        # model (parameter) related
        parser.add_argument('--use-batch-norm',
                            default=True,
                            type=strtobool,
                            help='Whether to use batch normalization')
        parser.add_argument(
            '--use-concate',
            default=True,
            type=strtobool,
            help='Whether to concatenate encoder embedding with decoder outputs'
        )
        parser.add_argument(
            '--use-residual',
            default=True,
            type=strtobool,
            help='Whether to use residual connection in conv layer')
        parser.add_argument('--dropout-rate',
                            default=0.5,
                            type=float,
                            help='Dropout rate')
        parser.add_argument('--zoneout-rate',
                            default=0.1,
                            type=float,
                            help='Zoneout rate')
        parser.add_argument('--reduction-factor',
                            default=1,
                            type=int,
                            help='Reduction factor')
        # loss related
        parser.add_argument(
            '--use-masking',
            default=False,
            type=strtobool,
            help='Whether to use masking in calculation of loss')
        parser.add_argument(
            '--bce-pos-weight',
            default=20.0,
            type=float,
            help=
            'Positive sample weight in BCE calculation (only for use-masking=True)'
        )
        parser.add_argument("--use-guided-attn-loss",
                            default=False,
                            type=strtobool,
                            help="Whether to use guided attention loss")
        parser.add_argument("--guided-attn-loss-sigma",
                            default=0.4,
                            type=float,
                            help="Sigma in guided attention loss")
        return

    def __init__(self, idim, odim, args):
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = getattr(args, "use_guided_attn_loss",
                                            False)

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such activation function (%s)' %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim, args.dunits, args.adim,
                               args.aconv_chans, args.aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(args)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma)
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(args)

    def forward(self,
                xs,
                ilens,
                ys,
                labels,
                olens,
                spembs=None,
                spcs=None,
                *args,
                **kwargs):
        """Tacotron2 forward computation

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor olens: batch of the lengths of each target (B)
        :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
        :param torch.Tensor spcs: batch of groundtruth spectrogram (B, Lmax, spc_dim)
        :return: loss value
        :rtype: torch.Tensor
        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {
                'l1_loss': l1_loss.item()
            },
            {
                'mse_loss': mse_loss.item()
            },
            {
                'bce_loss': bce_loss.item()
            },
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            attn_loss = self.attn_loss(att_ws, ilens, olens)
            loss = loss + attn_loss
            report_keys += [
                {
                    'attn_loss': attn_loss.item()
                },
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(
                cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {
                    'cbhg_l1_loss': cbhg_l1_loss.item()
                },
                {
                    'cbhg_mse_loss': cbhg_mse_loss.item()
                },
            ]

        report_keys += [{'loss': loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generates the sequence of features given the sequences of characters

        :param torch.Tensor x: the sequence of characters (T)
        :param Namespace inference_args: arguments containing the following attributes
            (float) threshold: threshold in inference
            (float) minlenratio: minimum length ratio in inference
            (float) maxlenratio: maximum length ratio in inference
        :param torch.Tensor spemb: speaker embedding vector (spk_embed_dim)
        :return: the sequence of features (L, odim)
        :rtype: torch.Tensor
        :return: the sequence of stop probabilities (L)
        :rtype: torch.Tensor
        :return: the sequence of attention weight (L, T)
        :rtype: torch.Tensor
        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb,
                                dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio,
                                                 maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws

    def calculate_all_attentions(self,
                                 xs,
                                 ilens,
                                 ys,
                                 spembs=None,
                                 *args,
                                 **kwargs):
        """Tacotron2 attention weight computation

        :param torch.Tensor xs: batch of padded character ids (B, Tmax)
        :param torch.Tensor ilens: list of lengths of each input batch (B)
        :param torch.Tensor ys: batch of padded target features (B, Lmax, odim)
        :param torch.Tensor spembs: batch of speaker embedding vector (B, spk_embed_dim)
        :return: attention weights (B, Lmax, Tmax)
        :rtype: numpy array
        """
        # check ilens type (should be list of int)
        if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            hs, hlens = self.enc(xs, ilens)
            if self.spk_embed_dim is not None:
                spembs = F.normalize(spembs).unsqueeze(1).expand(
                    -1, hs.size(1), -1)
                hs = torch.cat([hs, spembs], dim=-1)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """base key names to plot during training. keys should match what `chainer.reporter` reports

        if you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
        also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.

        :rtype list[str] plot_keys: base keys to plot during training
        """
        plot_keys = ['loss', 'l1_loss', 'mse_loss', 'bce_loss']
        if self.use_guided_attn_loss:
            plot_keys += ['attn_loss']
        if self.use_cbhg:
            plot_keys += ['cbhg_l1_loss', 'cbhg_mse_loss']
        return plot_keys
Example #6
0
class Tacotron2(TTSInterface, torch.nn.Module):
    """Tacotron2 module for end-to-end text-to-speech (E2E-TTS).

    This is a module of Spectrogram prediction network in Tacotron2 described in `Natural TTS Synthesis
    by Conditioning WaveNet on Mel Spectrogram Predictions`_, which converts the sequence of characters
    into the sequence of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument('--embed-dim', default=512, type=int,
                           help='Number of dimension of embedding')
        group.add_argument('--elayers', default=1, type=int,
                           help='Number of encoder layers')
        group.add_argument('--eunits', '-u', default=512, type=int,
                           help='Number of encoder hidden units')
        group.add_argument('--econv-layers', default=3, type=int,
                           help='Number of encoder convolution layers')
        group.add_argument('--econv-chans', default=512, type=int,
                           help='Number of encoder convolution channels')
        group.add_argument('--econv-filts', default=5, type=int,
                           help='Filter size of encoder convolution')
        # attention
        group.add_argument('--atype', default="location", type=str,
                           choices=["forward_ta", "forward", "location"],
                           help='Type of attention mechanism')
        group.add_argument('--adim', default=512, type=int,
                           help='Number of attention transformation dimensions')
        group.add_argument('--aconv-chans', default=32, type=int,
                           help='Number of attention convolution channels')
        group.add_argument('--aconv-filts', default=15, type=int,
                           help='Filter size of attention convolution')
        group.add_argument('--cumulate-att-w', default=True, type=strtobool,
                           help="Whether or not to cumulate attention weights")
        # decoder
        group.add_argument('--dlayers', default=2, type=int,
                           help='Number of decoder layers')
        group.add_argument('--dunits', default=1024, type=int,
                           help='Number of decoder hidden units')
        group.add_argument('--prenet-layers', default=2, type=int,
                           help='Number of prenet layers')
        group.add_argument('--prenet-units', default=256, type=int,
                           help='Number of prenet hidden units')
        group.add_argument('--postnet-layers', default=5, type=int,
                           help='Number of postnet layers')
        group.add_argument('--postnet-chans', default=512, type=int,
                           help='Number of postnet channels')
        group.add_argument('--postnet-filts', default=5, type=int,
                           help='Filter size of postnet')
        group.add_argument('--output-activation', default=None, type=str, nargs='?',
                           help='Output activation function')
        # cbhg
        group.add_argument('--use-cbhg', default=False, type=strtobool,
                           help='Whether to use CBHG module')
        group.add_argument('--cbhg-conv-bank-layers', default=8, type=int,
                           help='Number of convolutional bank layers in CBHG')
        group.add_argument('--cbhg-conv-bank-chans', default=128, type=int,
                           help='Number of convolutional bank channels in CBHG')
        group.add_argument('--cbhg-conv-proj-filts', default=3, type=int,
                           help='Filter size of convolutional projection layer in CBHG')
        group.add_argument('--cbhg-conv-proj-chans', default=256, type=int,
                           help='Number of convolutional projection channels in CBHG')
        group.add_argument('--cbhg-highway-layers', default=4, type=int,
                           help='Number of highway layers in CBHG')
        group.add_argument('--cbhg-highway-units', default=128, type=int,
                           help='Number of highway units in CBHG')
        group.add_argument('--cbhg-gru-units', default=256, type=int,
                           help='Number of GRU units in CBHG')
        # model (parameter) related
        group.add_argument('--use-batch-norm', default=True, type=strtobool,
                           help='Whether to use batch normalization')
        group.add_argument('--use-concate', default=True, type=strtobool,
                           help='Whether to concatenate encoder embedding with decoder outputs')
        group.add_argument('--use-residual', default=True, type=strtobool,
                           help='Whether to use residual connection in conv layer')
        group.add_argument('--dropout-rate', default=0.5, type=float,
                           help='Dropout rate')
        group.add_argument('--zoneout-rate', default=0.1, type=float,
                           help='Zoneout rate')
        group.add_argument('--reduction-factor', default=1, type=int,
                           help='Reduction factor')
        group.add_argument("--spk-embed-dim", default=None, type=int,
                           help="Number of speaker embedding dimensions")
        group.add_argument("--spc-dim", default=None, type=int,
                           help="Number of spectrogram dimensions")
        # loss related
        group.add_argument('--use-masking', default=False, type=strtobool,
                           help='Whether to use masking in calculation of loss')
        group.add_argument('--bce-pos-weight', default=20.0, type=float,
                           help='Positive sample weight in BCE calculation (only for use-masking=True)')
        group.add_argument("--use-guided-attn-loss", default=False, type=strtobool,
                           help="Whether to use guided attention loss")
        group.add_argument("--guided-attn-loss-sigma", default=0.4, type=float,
                           help="Sigma in guided attention loss")
        group.add_argument("--guided-attn-loss-lambda", default=1.0, type=float,
                           help="Lambda in guided attention loss")
        return parser

    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The number of attention conv filter size.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram embedding dimensions (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG.
                - cbhg_proj_filts (int): The filter size of the projection layer in CBHG.
                - cbhg_proj_chans (int): The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway network in CBHG.
                - cbhg_highway_units (int): The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such activation function (%s)' % args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           use_residual=args.use_residual,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim,
                         args.dunits,
                         args.adim,
                         args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim,
                             args.dunits,
                             args.adim,
                             args.aconv_chans,
                             args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled in forward attention.")
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim,
                               args.dunits,
                               args.adim,
                               args.aconv_chans,
                               args.aconv_filts,
                               odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
                                        bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)

    def forward(self, xs, ilens, ys, labels, olens, spembs=None, spcs=None, *args, **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            labels (LongTensor): Batch of the sequences of stop token labels (B, Lmax).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs (Tensor, optional): Batch of groundtruth spectrograms (B, Lmax, spc_dim).

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of ground truth
        if self.reduction_factor > 1:
            olens = olens.new([olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(
            after_outs, before_outs, logits, ys, labels, olens)
        loss = l1_loss + mse_loss + bce_loss
        report_keys = [
            {'l1_loss': l1_loss.item()},
            {'mse_loss': mse_loss.item()},
            {'bce_loss': bce_loss.item()},
        ]

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new([olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            report_keys += [
                {'attn_loss': attn_loss.item()},
            ]

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            report_keys += [
                {'cbhg_l1_loss': cbhg_l1_loss.item()},
                {'cbhg_mse_loss': cbhg_mse_loss.item()},
            ]

        report_keys += [{'loss': loss.item()}]
        self.reporter.report(report_keys)

        return loss

    def inference(self, x, inference_args, spemb=None, *args, **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace):
                - threshold (float): Threshold in inference.
                - minlenratio (float): Minimum length ratio in inference.
                - maxlenratio (float): Maximum length ratio in inference.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        # get options
        threshold = inference_args.threshold
        minlenratio = inference_args.minlenratio
        maxlenratio = inference_args.maxlenratio

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(h, threshold, minlenratio, maxlenratio)

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws

    def calculate_all_attentions(self, xs, ilens, ys, spembs=None, *args, **kwargs):
        """Calculate all of the attention weights.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim).

        Returns:
            numpy.ndarray: Batch of attention weights (B, Lmax, Tmax).

        """
        # check ilens type (should be list of int)
        if isinstance(ilens, torch.Tensor) or isinstance(ilens, np.ndarray):
            ilens = list(map(int, ilens))

        self.eval()
        with torch.no_grad():
            hs, hlens = self.enc(xs, ilens)
            if self.spk_embed_dim is not None:
                spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
                hs = torch.cat([hs, spembs], dim=-1)
            att_ws = self.dec.calculate_all_attentions(hs, hlens, ys)
        self.train()

        return att_ws.cpu().numpy()

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training. keys should match what `chainer.reporter` reports.

        If you add the key `loss`, the reporter will report `main/loss` and `validation/main/loss` values.
        also `loss.png` will be created as a figure visulizing `main/loss` and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ['loss', 'l1_loss', 'mse_loss', 'bce_loss']
        if self.use_guided_attn_loss:
            plot_keys += ['attn_loss']
        if self.use_cbhg:
            plot_keys += ['cbhg_l1_loss', 'cbhg_mse_loss']
        return plot_keys
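A quick stand-alone sketch of the reduction-factor handling used in `forward` above: target lengths are trimmed down to a multiple of `reduction_factor` and the last kept frame is forced to carry a stop label. The shapes and lengths below are made up for illustration only.

import torch

reduction_factor = 2                     # illustrative value
olens = torch.tensor([7, 5])             # target lengths (B,)
ys = torch.zeros(2, 7, 80)               # padded target features (B, Lmax, odim)
labels = torch.zeros(2, 7)               # stop-token labels (B, Lmax)

if reduction_factor > 1:
    # trim each length down to the nearest multiple of the reduction factor
    olens = olens - olens % reduction_factor
    max_out = int(olens.max())
    ys = ys[:, :max_out]
    labels = labels[:, :max_out]
    labels[:, -1] = 1.0                  # make sure at least one frame has a stop label

print(olens)     # tensor([6, 4])
print(ys.shape)  # torch.Size([2, 6, 80])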
Example #7
    def __init__(
        self,
        idim: int,
        odim: int,
        embed_dim: int = 512,
        elayers: int = 1,
        eunits: int = 512,
        econv_layers: int = 3,
        econv_chans: int = 512,
        econv_filts: int = 5,
        atype: str = "location",
        adim: int = 512,
        aconv_chans: int = 32,
        aconv_filts: int = 15,
        cumulate_att_w: bool = True,
        dlayers: int = 2,
        dunits: int = 1024,
        prenet_layers: int = 2,
        prenet_units: int = 256,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        output_activation: str = None,
        use_cbhg: bool = False,
        cbhg_conv_bank_layers: int = 8,
        cbhg_conv_bank_chans: int = 128,
        cbhg_conv_proj_filts: int = 3,
        cbhg_conv_proj_chans: int = 256,
        cbhg_highway_layers: int = 4,
        cbhg_highway_units: int = 128,
        cbhg_gru_units: int = 256,
        use_batch_norm: bool = True,
        use_concate: bool = True,
        use_residual: bool = False,
        dropout_rate: float = 0.5,
        zoneout_rate: float = 0.1,
        reduction_factor: int = 1,
        spk_embed_dim: int = None,
        spc_dim: int = None,
        use_masking: bool = True,
        use_weighted_masking: bool = False,
        bce_pos_weight: float = 5.0,
        use_guided_attn_loss: bool = True,
        guided_attn_loss_sigma: float = 0.4,
        guided_attn_loss_lambda: float = 1.0,
    ):
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.eos = idim - 1
        self.spk_embed_dim = spk_embed_dim
        self.cumulate_att_w = cumulate_att_w
        self.reduction_factor = reduction_factor
        self.use_cbhg = use_cbhg
        self.use_guided_attn_loss = use_guided_attn_loss

        # define activation function for the final output
        if output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, output_activation):
            self.output_activation_fn = getattr(F, output_activation)
        else:
            raise ValueError(f"there is no such an activation function. "
                             f"({output_activation})")

        # set padding idx
        padding_idx = 0
        self.padding_idx = padding_idx

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=embed_dim,
            elayers=elayers,
            eunits=eunits,
            econv_layers=econv_layers,
            econv_chans=econv_chans,
            econv_filts=econv_filts,
            use_batch_norm=use_batch_norm,
            use_residual=use_residual,
            dropout_rate=dropout_rate,
            padding_idx=padding_idx,
        )

        dec_idim = eunits if spk_embed_dim is None else eunits + spk_embed_dim
        if atype == "location":
            att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
        elif atype == "forward":
            att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        elif atype == "forward_ta":
            att = AttForwardTA(dec_idim, dunits, adim, aconv_chans,
                               aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=dlayers,
            dunits=dunits,
            prenet_layers=prenet_layers,
            prenet_units=prenet_units,
            postnet_layers=postnet_layers,
            postnet_chans=postnet_chans,
            postnet_filts=postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=use_batch_norm,
            use_concate=use_concate,
            dropout_rate=dropout_rate,
            zoneout_rate=zoneout_rate,
            reduction_factor=reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=spc_dim,
                conv_bank_layers=cbhg_conv_bank_layers,
                conv_bank_chans=cbhg_conv_bank_chans,
                conv_proj_filts=cbhg_conv_proj_filts,
                conv_proj_chans=cbhg_conv_proj_chans,
                highway_layers=cbhg_highway_layers,
                highway_units=cbhg_highway_units,
                gru_units=cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=use_masking)
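In the constructor above, `dec_idim` grows by `spk_embed_dim` because, at run time, the speaker embedding is normalized, broadcast over time, and concatenated to the encoder outputs (as in the forward passes shown in the other examples). A minimal sketch of that step with made-up dimensions, not tied to any particular configuration:

import torch
import torch.nn.functional as F

eunits, spk_embed_dim = 512, 64
dec_idim = eunits if spk_embed_dim is None else eunits + spk_embed_dim   # 576

hs = torch.randn(4, 123, eunits)         # encoder outputs (B, Tmax, eunits)
spembs = torch.randn(4, spk_embed_dim)   # speaker embeddings (B, spk_embed_dim)

# L2-normalize, add a time axis, and repeat the embedding for every encoder frame
spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
hs = torch.cat([hs, spembs], dim=-1)     # (B, Tmax, eunits + spk_embed_dim)

assert hs.size(-1) == dec_idim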
Example #8
class Tacotron2(AbsTTS):
    """Tacotron2 module for end-to-end text-to-speech.

    This is a module of Spectrogram prediction network in Tacotron2 described
    in `Natural TTS Synthesis
    by Conditioning WaveNet on Mel Spectrogram Predictions`_, which converts
    the sequence of characters into the sequence of Mel-filterbanks.

    .. _`Natural TTS Synthesis by Conditioning WaveNet on Mel Spectrogram Predictions`:
       https://arxiv.org/abs/1712.05884

    Args:
        idim: Dimension of the inputs.
        odim: Dimension of the outputs.
        spk_embed_dim: Dimension of the speaker embedding.
        embed_dim: Dimension of character embedding.
        elayers: The number of encoder blstm layers.
        eunits: The number of encoder blstm units.
        econv_layers: The number of encoder conv layers.
        econv_filts: The number of encoder conv filter size.
        econv_chans: The number of encoder conv filter channels.
        dlayers: The number of decoder lstm layers.
        dunits: The number of decoder lstm units.
        prenet_layers: The number of prenet layers.
        prenet_units: The number of prenet units.
        postnet_layers: The number of postnet layers.
        postnet_filts: The number of postnet filter size.
        postnet_chans: The number of postnet filter channels.
        output_activation: The name of activation function for outputs.
        adim: The number of dimension of mlp in attention.
        aconv_chans: The number of attention conv filter channels.
        aconv_filts: The number of attention conv filter size.
        cumulate_att_w: Whether to cumulate previous attention weight.
        use_batch_norm: Whether to use batch normalization.
        use_concate: Whether to concatenate encoder embedding with decoder
            lstm outputs.
        dropout_rate: Dropout rate.
        zoneout_rate: Zoneout rate.
        reduction_factor: Reduction factor.
        spc_dim: Number of spectrogram embedding dimensions
            (only for use_cbhg=True).
        use_cbhg: Whether to use CBHG module.
        cbhg_conv_bank_layers: The number of convolutional banks in CBHG.
        cbhg_conv_bank_chans: The number of channels of convolutional bank in
            CBHG.
        cbhg_conv_proj_filts: The number of filter size of projection layer in
            CBHG.
        cbhg_conv_proj_chans: The number of channels of projection layer in CBHG.
        cbhg_highway_layers: The number of layers of highway network in CBHG.
        cbhg_highway_units: The number of units of highway network in CBHG.
        cbhg_gru_units: The number of units of GRU in CBHG.
        use_masking: Whether to mask padded part in loss calculation.
        use_weighted_masking: Whether to apply weighted masking in
            loss calculation.
        bce_pos_weight: Weight of positive sample of stop token
            (only for use_masking=True).
        use_guided_attn_loss: Whether to use guided attention loss.
        guided_attn_loss_sigma: Sigma in guided attention loss.
        guided_attn_loss_lambda: Lambda in guided attention loss.
    """
    def __init__(
        self,
        idim: int,
        odim: int,
        embed_dim: int = 512,
        elayers: int = 1,
        eunits: int = 512,
        econv_layers: int = 3,
        econv_chans: int = 512,
        econv_filts: int = 5,
        atype: str = "location",
        adim: int = 512,
        aconv_chans: int = 32,
        aconv_filts: int = 15,
        cumulate_att_w: bool = True,
        dlayers: int = 2,
        dunits: int = 1024,
        prenet_layers: int = 2,
        prenet_units: int = 256,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        output_activation: str = None,
        use_cbhg: bool = False,
        cbhg_conv_bank_layers: int = 8,
        cbhg_conv_bank_chans: int = 128,
        cbhg_conv_proj_filts: int = 3,
        cbhg_conv_proj_chans: int = 256,
        cbhg_highway_layers: int = 4,
        cbhg_highway_units: int = 128,
        cbhg_gru_units: int = 256,
        use_batch_norm: bool = True,
        use_concate: bool = True,
        use_residual: bool = False,
        dropout_rate: float = 0.5,
        zoneout_rate: float = 0.1,
        reduction_factor: int = 1,
        spk_embed_dim: int = None,
        spc_dim: int = None,
        use_masking: bool = True,
        use_weighted_masking: bool = False,
        bce_pos_weight: float = 5.0,
        use_guided_attn_loss: bool = True,
        guided_attn_loss_sigma: float = 0.4,
        guided_attn_loss_lambda: float = 1.0,
    ):
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.eos = idim - 1
        self.spk_embed_dim = spk_embed_dim
        self.cumulate_att_w = cumulate_att_w
        self.reduction_factor = reduction_factor
        self.use_cbhg = use_cbhg
        self.use_guided_attn_loss = use_guided_attn_loss

        # define activation function for the final output
        if output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, output_activation):
            self.output_activation_fn = getattr(F, output_activation)
        else:
            raise ValueError(f"there is no such an activation function. "
                             f"({output_activation})")

        # set padding idx
        padding_idx = 0
        self.padding_idx = padding_idx

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=embed_dim,
            elayers=elayers,
            eunits=eunits,
            econv_layers=econv_layers,
            econv_chans=econv_chans,
            econv_filts=econv_filts,
            use_batch_norm=use_batch_norm,
            use_residual=use_residual,
            dropout_rate=dropout_rate,
            padding_idx=padding_idx,
        )

        dec_idim = eunits if spk_embed_dim is None else eunits + spk_embed_dim
        if atype == "location":
            att = AttLoc(dec_idim, dunits, adim, aconv_chans, aconv_filts)
        elif atype == "forward":
            att = AttForward(dec_idim, dunits, adim, aconv_chans, aconv_filts)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        elif atype == "forward_ta":
            att = AttForwardTA(dec_idim, dunits, adim, aconv_chans,
                               aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning("cumulation of attention weights is disabled "
                                "in forward attention.")
                self.cumulate_att_w = False
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=dlayers,
            dunits=dunits,
            prenet_layers=prenet_layers,
            prenet_units=prenet_units,
            postnet_layers=postnet_layers,
            postnet_chans=postnet_chans,
            postnet_filts=postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=use_batch_norm,
            use_concate=use_concate,
            dropout_rate=dropout_rate,
            zoneout_rate=zoneout_rate,
            reduction_factor=reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
            bce_pos_weight=bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=guided_attn_loss_sigma,
                alpha=guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=spc_dim,
                conv_bank_layers=cbhg_conv_bank_layers,
                conv_bank_chans=cbhg_conv_bank_chans,
                conv_proj_filts=cbhg_conv_proj_filts,
                conv_proj_chans=cbhg_conv_proj_chans,
                highway_layers=cbhg_highway_layers,
                highway_units=cbhg_highway_units,
                gru_units=cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=use_masking)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        spembs: torch.Tensor = None,
        spcs: torch.Tensor = None,
        spcs_lengths: torch.Tensor = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text: Batch of padded character ids (B, Tmax).
            text_lengths: Batch of lengths of each input batch (B,).
            speech: Batch of padded target features (B, Lmax, odim).
            speech_lengths: Batch of the lengths of each target (B,).
            spembs: Batch of speaker embedding vectors (B, spk_embed_dim).
            spcs: Batch of ground-truth spectrograms (B, Lmax, spc_dim).
            spcs_lengths: Batch of the lengths of each spectrogram (B,).
        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel

        batch_size = text.size(0)
        # Add eos at the last of sequence
        xs = F.pad(text, [0, 1], "constant", 0.0)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys = speech
        olens = speech_lengths

        # make labels for stop prediction
        labels = make_pad_mask(olens).to(ys.device, ys.dtype)

        # calculate tacotron2 outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)
        after_outs, before_outs, logits, att_ws = self.dec(hs, hlens, ys)

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]
            labels = labels[:, :max_out]
            labels[:, -1] = 1.0  # make sure at least one frame has 1

        # calculate taco2 loss
        l1_loss, mse_loss, bce_loss = self.taco2_loss(after_outs, before_outs,
                                                      logits, ys, labels,
                                                      olens)
        loss = l1_loss + mse_loss + bce_loss

        stats = dict(
            l1_loss=l1_loss.item(),
            mse_loss=mse_loss.item(),
            bce_loss=bce_loss.item(),
        )

        # calculate attention loss
        if self.use_guided_attn_loss:
            # NOTE(kan-bayashi): length of output for auto-regressive
            # input will be changed when r > 1
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            attn_loss = self.attn_loss(att_ws, ilens, olens_in)
            loss = loss + attn_loss
            stats.update(attn_loss=attn_loss.item())

        # calculate cbhg loss
        if self.use_cbhg:
            # remove unnecessary padded part (for multi-gpus)
            max_out = max(olens)
            if max_out != spcs.shape[1]:
                spcs = spcs[:, :max_out]

            # calculate cbhg outputs & loss and report them
            cbhg_outs, _ = self.cbhg(after_outs, olens)
            cbhg_l1_loss, cbhg_mse_loss = self.cbhg_loss(
                cbhg_outs, spcs, olens)
            loss = loss + cbhg_l1_loss + cbhg_mse_loss
            stats.update(
                cbhg_l1_loss=cbhg_l1_loss.item(),
                cbhg_mse_loss=cbhg_mse_loss.item(),
            )

        stats.update(loss=loss.item())

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight

    def inference(
        self,
        text: torch.Tensor,
        spembs: torch.Tensor = None,
        threshold: float = 0.5,
        minlenratio: float = 0.0,
        maxlenratio: float = 10.0,
        use_att_constraint: bool = False,
        backward_window: int = 1,
        forward_window: int = 3,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text: Input sequence of characters (T,).
            spembs: Speaker embedding vector (spk_embed_dim,).
            threshold: Threshold in inference.
            minlenratio: Minimum length ratio in inference.
            maxlenratio: Maximum length ratio in inference.
            use_att_constraint: Whether to apply attention constraint.
            backward_window: Backward window in attention constraint.
            forward_window: Forward window in attention constraint.

        Returns:
            Tensor: Output sequence of features (L, odim).
            Tensor: Output sequence of stop probabilities (L,).
            Tensor: Attention weights (L, T).

        """
        x = text
        spemb = spembs

        # inference
        h = self.enc.inference(x)
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb,
                                dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        outs, probs, att_ws = self.dec.inference(
            h,
            threshold=threshold,
            minlenratio=minlenratio,
            maxlenratio=maxlenratio,
            use_att_constraint=use_att_constraint,
            backward_window=backward_window,
            forward_window=forward_window,
        )

        if self.use_cbhg:
            cbhg_outs = self.cbhg.inference(outs)
            return cbhg_outs, probs, att_ws
        else:
            return outs, probs, att_ws
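A rough stand-alone sketch of how `forward` above prepares its inputs: an eos id is appended to every text sequence and stop-token labels are derived from the padding mask. The small `pad_mask` helper below only mimics ESPnet's `make_pad_mask` for illustration; all values are made up.

import torch
import torch.nn.functional as F

def pad_mask(lengths: torch.Tensor) -> torch.Tensor:
    """Return a (B, Lmax) mask that is True at padded positions (mimics make_pad_mask)."""
    max_len = int(lengths.max())
    return torch.arange(max_len).unsqueeze(0) >= lengths.unsqueeze(1)

idim = 40                                            # vocabulary size, so eos = idim - 1
eos = idim - 1
text = torch.tensor([[3, 7, 9, 0], [5, 2, 0, 0]])    # padded character ids (B, Tmax)
text_lengths = torch.tensor([3, 2])

# add eos at the end of each sequence
xs = F.pad(text, [0, 1], "constant", 0)
for i, l in enumerate(text_lengths):
    xs[i, l] = eos
ilens = text_lengths + 1

# stop-token labels: 0 inside the utterance, 1 on padded frames
speech_lengths = torch.tensor([6, 4])
labels = pad_mask(speech_lengths).float()
print(xs)       # eos (39) sits right after the last real id of each row
print(labels)   # tensor([[0., 0., 0., 0., 0., 0.], [0., 0., 0., 0., 1., 1.]])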
Example #9
    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (int): The name of activation function for outputs.
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The number of attention conv filter size.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (int): Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spkidloss_weight (float): Weight of the speaker ID loss when combined with the TTS loss.
                - num_spk (int): Number of speakers in the training data.
                - spc_dim (int): Number of spectrogram embedding dimensions (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG.
                - cbhg_proj_filts (int): The number of filter size of projection layer in CBHG.
                - cbhg_proj_chans (int): The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway network in CBHG.
                - cbhg_highway_units (int): The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.spkidloss_weight = args.spkidloss_weight
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such activation function. (%s)' %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           use_residual=args.use_residual,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        if args.train_spkid_extractor:
            self.train_spkid_extractor = True
            self.resnet_spkid = E2E_speakerid(input_dim=odim,
                                              output_dim=args.num_spk,
                                              Q=odim - 1,
                                              D=32,
                                              hidden_dim=args.spk_embed_dim,
                                              pooling='mean',
                                              network_type='lde',
                                              distance_type='sqr',
                                              asoftmax=True,
                                              resnet_AvgPool2d_fre_ksize=10)
            self.angle_loss = AngleLoss()
        else:
            self.train_spkid_extractor = False
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim, args.dunits, args.adim,
                               args.aconv_chans, args.aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "noatt":  # This condition is satisfied only when using phone alignment for TTS input (Currently when using phn. ali. in TTS training for learning speaker embedding)
            att = None
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
                                        bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
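All of these variants can enable GuidedAttentionLoss through guided_attn_loss_sigma and guided_attn_loss_lambda. As a sketch of the underlying idea (following Tachibana et al., 2017; written independently of the ESPnet class, whose masking and reduction details may differ), the loss multiplies the attention matrix by a penalty that is near zero on the diagonal and grows away from it:

import torch

def guided_attention_weights(ilen: int, olen: int, sigma: float = 0.4) -> torch.Tensor:
    """Return an (olen, ilen) penalty matrix W[t, n] = 1 - exp(-((n/ilen - t/olen)^2) / (2*sigma^2))."""
    grid_t = torch.arange(olen).float().unsqueeze(1)   # output frame index t
    grid_n = torch.arange(ilen).float().unsqueeze(0)   # input token index n
    return 1.0 - torch.exp(-((grid_n / ilen - grid_t / olen) ** 2) / (2 * sigma ** 2))

w = guided_attention_weights(ilen=50, olen=100, sigma=0.4)

# a roughly diagonal attention matrix is barely penalized ...
att_diag = torch.zeros(100, 50)
att_diag[torch.arange(100), torch.arange(100) * 50 // 100] = 1.0
# ... while uniform ("flat") attention is penalized much more
att_flat = torch.full((100, 50), 1.0 / 50)

print((w * att_diag).mean().item())   # close to zero
print((w * att_flat).mean().item())   # noticeably larger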