Example #1
    def __init__(self, idim, odim, args=None):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
                - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int): Kernel size of duration predictor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type (str): How to integrate speaker embedding.
                - teacher_model (str): Teacher auto-regressive transformer model path.
                - reduction_factor (int): Reduction factor.
                - transformer_init (float): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
                - use_masking (bool): Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
                - transfer_encoder_from_teacher (bool): Whether to transfer encoder using teacher encoder parameters.
                - transferred_encoder_module (str): Encoder module to be initialized using teacher parameters.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = args.reduction_factor
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=args.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=args.adim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use the encoder as the decoder because FastSpeech's
        # decoder has the same architecture as its encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=None,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)

        # define postnet
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # initialize parameters
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

        # define teacher model
        if args.teacher_model is not None:
            self.teacher = self._load_teacher_model(args.teacher_model)
        else:
            self.teacher = None

        # define duration calculator
        if self.teacher is not None:
            self.duration_calculator = DurationCalculator(self.teacher)
        else:
            self.duration_calculator = None

        # transfer teacher parameters
        if self.teacher is not None and args.transfer_encoder_from_teacher:
            self._transfer_from_teacher(args.transferred_encoder_module)

        # define criterions
        self.criterion = FeedForwardTransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking)
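
Note: a minimal sketch of the fill_missing_args pattern used above, assuming an argparse-based add_arguments hook; the names and defaults below are illustrative, not the actual ESPnet implementation.

import argparse

def add_arguments(parser):
    # defaults mirror a small subset of the hyperparameters documented above
    parser.add_argument("--adim", type=int, default=384)
    parser.add_argument("--aheads", type=int, default=4)
    parser.add_argument("--reduction-factor", type=int, default=1)
    return parser

def fill_missing_args(args, add_arguments_fn):
    # parse an empty command line to collect all defaults, then overlay
    # whatever the caller actually provided
    defaults = vars(add_arguments_fn(argparse.ArgumentParser()).parse_args([]))
    defaults.update({} if args is None else vars(args))
    return argparse.Namespace(**defaults)

args = fill_missing_args(argparse.Namespace(adim=512), add_arguments)
print(args.adim, args.aheads)  # 512 4 -- explicit values win, defaults backfill
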
Example #2
    def __init__(self, idim, odim, args=None):
        """Initialize TTS-Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - embed_dim (int): Dimension of character embedding.
                - eprenet_conv_layers (int): Number of encoder prenet convolution layers.
                - eprenet_conv_chans (int): Number of encoder prenet convolution channels.
                - eprenet_conv_filts (int): Filter size of encoder prenet convolution.
                - dprenet_layers (int): Number of decoder prenet layers.
                - dprenet_units (int): Number of decoder prenet hidden units.
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - postnet_layers (int): Number of postnet layers.
                - postnet_chans (int): Number of postnet channels.
                - postnet_filts (int): Filter size of postnet.
                - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding.
                - use_batch_norm (bool): Whether to use batch normalization in encoder prenet.
                - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type (str): How to integrate speaker embedding.
                - transformer_init (float): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-decoder attention module.
                - eprenet_dropout_rate (float): Dropout rate in encoder prenet.
                - dprenet_dropout_rate (float): Dropout rate in decoder prenet.
                - postnet_dropout_rate (float): Dropout rate in postnet.
                - use_masking (bool): Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation.
                - bce_pos_weight (float): Positive sample weight in BCE calculation (only for use_masking=True).
                - loss_type (str): How to calculate loss.
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - num_heads_applied_guided_attn (int): Number of heads in each layer to apply guided attention loss.
                - num_layers_applied_guided_attn (int): Number of layers to apply guided attention loss.
                - modules_applied_guided_attn (list): List of module names to apply guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define transformer encoder
        if args.eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(idim=idim,
                              embed_dim=args.embed_dim,
                              elayers=0,
                              econv_layers=args.eprenet_conv_layers,
                              econv_chans=args.eprenet_conv_chans,
                              econv_filts=args.eprenet_conv_filts,
                              use_batch_norm=args.use_batch_norm,
                              dropout_rate=args.eprenet_dropout_rate,
                              padding_idx=padding_idx),
                torch.nn.Linear(args.eprenet_conv_chans, args.adim))
        else:
            encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                     embedding_dim=args.adim,
                                                     padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # define projection layer
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define transformer decoder
        if args.dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(idim=odim,
                              n_layers=args.dprenet_layers,
                              n_units=args.dprenet_units,
                              dropout_rate=args.dprenet_dropout_rate),
                torch.nn.Linear(args.dprenet_units, args.adim))
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # define postnet
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # define loss function
        self.criterion = TransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )

        # initialize parameters
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)
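
Note: a hedged sketch of the two positional-encoding variants selected via use_scaled_pos_enc above. These are simplified stand-ins for illustration, not the ESPnet classes (which also take a dropout rate); the trainable alpha is what init_enc_alpha/init_dec_alpha reset elsewhere.

import math
import torch

class PositionalEncoding(torch.nn.Module):
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        # standard fixed sinusoidal table, shape (1, max_len, d_model)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
        div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                        * -(math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        pe[:, 1::2] = torch.cos(pos * div)
        self.register_buffer("pe", pe.unsqueeze(0))

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class ScaledPositionalEncoding(PositionalEncoding):
    # adds a trainable scalar so the model can rebalance
    # token embeddings against positions
    def __init__(self, d_model, max_len=5000):
        super().__init__(d_model, max_len)
        self.alpha = torch.nn.Parameter(torch.tensor(1.0))

    def forward(self, x):
        return x + self.alpha * self.pe[:, :x.size(1)]
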
Example #3
    def __init__(self, idim, odim, args):
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = getattr(args, "use_guided_attn_loss",
                                            False)

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError("there is no such activation function (%s)" %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        dec_idim = args.eunits if args.spk_embed_dim is None else args.eunits + args.spk_embed_dim
        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim, args.dunits, args.adim,
                               args.aconv_chans, args.aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Support only location, forward, or forward_ta")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(args)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma)
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(args)
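
Note: the output_activation handling above resolves an activation by name from torch.nn.functional via getattr; a small self-contained sketch of that pattern (resolve_activation is an illustrative helper, not part of the source).

import torch
import torch.nn.functional as F

def resolve_activation(name):
    # None -> no output activation; otherwise look the function up by name
    if name is None:
        return None
    if hasattr(F, name):
        return getattr(F, name)
    raise ValueError("there is no such activation function (%s)" % name)

act = resolve_activation("relu")
print(act(torch.tensor([-1.0, 2.0])))  # tensor([0., 2.])
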
Example #4
    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): Filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): Filter size of postnet layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - adim (int): The number of dimension of mlp in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): Filter size of attention conv layers.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool):
                    Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int):
                    The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int):
                    The number of channels of convolutional bank in CBHG.
                - cbhg_conv_proj_filts (int):
                    Filter size of the projection layer in CBHG.
                - cbhg_conv_proj_chans (int):
                    The number of channels of projection layer in CBHG.
                - cbhg_highway_layers (int):
                    The number of layers of highway network in CBHG.
                - cbhg_highway_units (int):
                    The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token
                    (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.adim = args.adim
        self.spk_embed_dim = args.spk_embed_dim
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.encoder_reduction_factor = args.encoder_reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss
        self.src_reconstruction_loss_lambda = args.src_reconstruction_loss_lambda
        self.trg_reconstruction_loss_lambda = args.trg_reconstruction_loss_lambda

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)" % args.output_activation
            )

        # define network modules
        self.enc = Encoder(
            idim=idim * args.encoder_reduction_factor,
            input_layer="linear",
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
        )
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        if args.atype == "location":
            att = AttLoc(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
        elif args.atype == "forward":
            att = AttForward(
                dec_idim, args.dunits, args.adim, args.aconv_chans, args.aconv_filts
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(
                dec_idim,
                args.dunits,
                args.adim,
                args.aconv_chans,
                args.aconv_filts,
                odim,
            )
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        else:
            raise NotImplementedError(
                "Support only location, forward, or forward_ta")
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            att=att,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            cumulate_att_w=self.cumulate_att_w,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
        )
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking, bce_pos_weight=args.bce_pos_weight
        )
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(
                idim=odim,
                odim=args.spc_dim,
                conv_bank_layers=args.cbhg_conv_bank_layers,
                conv_bank_chans=args.cbhg_conv_bank_chans,
                conv_proj_filts=args.cbhg_conv_proj_filts,
                conv_proj_chans=args.cbhg_conv_proj_chans,
                highway_layers=args.cbhg_highway_layers,
                highway_units=args.cbhg_highway_units,
                gru_units=args.cbhg_gru_units,
            )
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
        if self.src_reconstruction_loss_lambda > 0:
            self.src_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            self.src_reconstructor_linear = torch.nn.Linear(
                args.econv_chans, idim * args.encoder_reduction_factor
            )

            self.src_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)
        if self.trg_reconstruction_loss_lambda > 0:
            self.trg_reconstructor = Encoder(
                idim=dec_idim,
                input_layer="linear",
                elayers=args.elayers,
                eunits=args.eunits,
                econv_layers=args.econv_layers,
                econv_chans=args.econv_chans,
                econv_filts=args.econv_filts,
                use_batch_norm=args.use_batch_norm,
                use_residual=args.use_residual,
                dropout_rate=args.dropout_rate,
            )
            self.trg_reconstructor_linear = torch.nn.Linear(
                args.econv_chans, odim * args.reduction_factor
            )
            self.trg_reconstruction_loss = CBHGLoss(use_masking=args.use_masking)

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)
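
Note: idim * encoder_reduction_factor above suggests adjacent input frames are stacked along the feature axis so the encoder sees fewer, wider timesteps. A hedged sketch under that assumption; stack_frames is an illustrative helper, not part of the source.

import torch

def stack_frames(x, r):
    # x: (B, T, idim) -> (B, T // r, idim * r); trailing frames are dropped
    B, T, D = x.shape
    T = (T // r) * r
    return x[:, :T].reshape(B, T // r, D * r)

x = torch.randn(2, 10, 80)
print(stack_frames(x, 2).shape)  # torch.Size([2, 5, 160])
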
Example #5
    def __init__(self, idim, odim, args=None):
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = args.reduction_factor
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.use_masking = args.use_masking
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type

        # TODO(kan-bayashi): support reduction_factor > 1
        if self.reduction_factor != 1:
            raise NotImplementedError("Support only reduction_factor = 1.")

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=args.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=args.adim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use the encoder as the decoder because FastSpeech's
        # decoder has the same architecture as its encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=None,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)

        # initialize parameters
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

        # define teacher model
        if args.teacher_model is not None:
            self.teacher = self._load_teacher_model(args.teacher_model)
        else:
            self.teacher = None

        # define duration calculator
        if self.teacher is not None:
            self.duration_calculator = DurationCalculator(self.teacher)
        else:
            self.duration_calculator = None

        # transfer teacher parameters
        if self.teacher is not None and args.transfer_encoder_from_teacher:
            self._transfer_from_teacher(args.transferred_encoder_module)

        # define criterions
        self.duration_criterion = DurationPredictorLoss()
        # TODO(kan-bayashi): support knowledge distillation loss
        self.criterion = torch.nn.L1Loss()
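
Note: a hedged sketch of what the LengthRegulator above does: each encoder state is repeated by its predicted integer duration. This minimal single-sequence version omits the batching and padding the real module presumably handles.

import torch

def length_regulate(hs, durations):
    # hs: (T, adim), durations: (T,) non-negative integer frame counts
    return torch.repeat_interleave(hs, durations, dim=0)

hs = torch.randn(3, 4)
out = length_regulate(hs, torch.tensor([2, 0, 3]))  # zero skips a token
print(out.shape)  # torch.Size([5, 4])
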
Example #6
    def __init__(self, idim, odim, args=None, com_args=None):
        """Initialize Tacotron2 module.
        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): Filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): Filter size of postnet layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of activation function for outputs.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                    with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True)
                - use_masking (bool):
                    Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool):
                    Whether to apply weighted masking in loss calculation.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int):
                    Kernel size of duration predictor.
        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # inherit shared options from com_args when they are absent in args
        args = vars(args)
        if 'use_fe_condition' not in args:
            args['use_fe_condition'] = com_args.use_fe_condition
        if 'append_position' not in args:
            args['append_position'] = com_args.append_position
        args = argparse.Namespace(**args)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.embed_dim = args.embed_dim
        self.spk_embed_dim = args.spk_embed_dim
        self.reduction_factor = args.reduction_factor
        self.use_fe_condition = args.use_fe_condition
        self.append_position = args.append_position

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError("there is no such an activation function. (%s)" %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            resume=args.encoder_resume,
        )
        dec_idim = (args.eunits if args.spk_embed_dim is None else
                    args.eunits + args.spk_embed_dim)

        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
            use_fe_condition=args.use_fe_condition,
            append_position=args.append_position,
        )

        self.duration_predictor = DurationPredictor(
            idim=dec_idim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )
        reduction = 'none' if args.use_weighted_masking else 'mean'
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

        # -------------- pitch/energy predictor definition --------------- #
        if self.use_fe_condition:
            output_dim = 1
            # pitch prediction
            pitch_predictor_layers = 2
            pitch_predictor_chans = 384
            pitch_predictor_kernel_size = 3
            pitch_predictor_dropout_rate = 0.5
            pitch_embed_kernel_size = 9
            pitch_embed_dropout_rate = 0.5
            self.stop_gradient_from_pitch_predictor = False
            self.pitch_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=pitch_predictor_layers,
                n_chans=pitch_predictor_chans,
                kernel_size=pitch_predictor_kernel_size,
                dropout_rate=pitch_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.pitch_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=pitch_embed_kernel_size,
                    padding=(pitch_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(pitch_embed_dropout_rate),
            )
            # energy prediction
            energy_predictor_layers = 2
            energy_predictor_chans = 384
            energy_predictor_kernel_size = 3
            energy_predictor_dropout_rate = 0.5
            energy_embed_kernel_size = 9
            energy_embed_dropout_rate = 0.5
            self.stop_gradient_from_energy_predictor = False
            self.energy_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=energy_predictor_layers,
                n_chans=energy_predictor_chans,
                kernel_size=energy_predictor_kernel_size,
                dropout_rate=energy_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.energy_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=energy_embed_kernel_size,
                    padding=(energy_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(energy_embed_dropout_rate),
            )
            # define criterions
            self.prosody_criterion = prosody_criterions(
                use_masking=args.use_masking,
                use_weighted_masking=args.use_weighted_masking)

        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

        print('\n############## number of network parameters ##############\n')

        def count_params(module):
            # count trainable parameters in millions
            return sum(p.numel() for p in module.parameters()
                       if p.requires_grad) / 1_000_000

        print('Trainable Parameters for Encoder: %.5fM' % count_params(self.enc))
        print('Trainable Parameters for Decoder: %.5fM' % count_params(self.dec))
        print('Trainable Parameters for duration_predictor: %.5fM' %
              count_params(self.duration_predictor))
        if self.use_fe_condition:
            # pitch/energy modules exist only when use_fe_condition is set
            print('Trainable Parameters for pitch_predictor: %.5fM' %
                  count_params(self.pitch_predictor))
            print('Trainable Parameters for energy_predictor: %.5fM' %
                  count_params(self.energy_predictor))
            print('Trainable Parameters for pitch_embed: %.5fM' %
                  count_params(self.pitch_embed))
            print('Trainable Parameters for energy_embed: %.5fM' %
                  count_params(self.energy_embed))
        print('Trainable Parameters for whole network: %.5fM' %
              count_params(self))

        print('\n##########################################################\n')
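
Note: the pitch_embed/energy_embed convolutions above use padding=(kernel_size - 1) // 2, which at stride 1 preserves sequence length for odd kernels; a quick self-contained check of that property.

import torch

k = 9
conv = torch.nn.Conv1d(1, 384, kernel_size=k, padding=(k - 1) // 2)
x = torch.randn(2, 1, 100)   # (B, 1, T): one pitch value per frame
print(conv(x).shape)         # torch.Size([2, 384, 100]) -- T unchanged
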
Example #7
    def __init__(self, idim, odim, args):
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = args.num_layers_applied_guided_attn
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define transformer encoder
        if args.eprenet_conv_layers != 0:
            # encoder prenet
            encoder_input_layer = torch.nn.Sequential(
                EncoderPrenet(idim=idim,
                              embed_dim=args.embed_dim,
                              elayers=0,
                              econv_layers=args.eprenet_conv_layers,
                              econv_chans=args.eprenet_conv_chans,
                              econv_filts=args.eprenet_conv_filts,
                              use_batch_norm=args.use_batch_norm,
                              dropout_rate=args.eprenet_dropout_rate,
                              padding_idx=padding_idx),
                torch.nn.Linear(args.eprenet_conv_chans, args.adim))
        else:
            encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                     embedding_dim=args.adim,
                                                     padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after)

        # define transformer decoder
        if args.dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(idim=odim,
                              n_layers=args.dprenet_layers,
                              n_units=args.dprenet_units,
                              dropout_rate=args.dprenet_dropout_rate),
                torch.nn.Linear(args.dprenet_units, args.adim))
        else:
            decoder_input_layer = "linear"
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after)

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # define postnet
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # define loss function
        self.criterion = TransformerLoss(args)
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                args.guided_attn_loss_sigma)

        # initialize parameters
        self._reset_parameters(args)
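
Note: feat_out above maps adim to odim * reduction_factor, i.e. each decoder step emits reduction_factor spectrogram frames; a hedged sketch of how such an output is typically folded back into per-frame features (variable names are illustrative).

import torch

B, T, adim, odim, r = 2, 6, 8, 80, 2
feat_out = torch.nn.Linear(adim, odim * r)
zs = torch.randn(B, T, adim)          # decoder hidden states
before = feat_out(zs)                 # (B, T, odim * r)
frames = before.view(B, T * r, odim)  # each step unfolds into r frames
print(frames.shape)                   # torch.Size([2, 12, 80])
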
Example #8
    def __init__(self, idim, odim, args=None):

        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = (
                    args.num_layers_applied_guided_attn)
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = args.num_heads_applied_guided_attn
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define transformer encoder
        # NOTE: the encoder prenet branch from the other examples is kept
        # here only as a disabled reference; this variant instead builds its
        # encoder below with the decoder-style input layer.
        # if args.eprenet_conv_layers != 0:
        #     # encoder prenet
        #     encoder_input_layer = torch.nn.Sequential(
        #         EncoderPrenet(
        #             idim=idim,
        #             embed_dim=args.embed_dim,
        #             elayers=0,
        #             econv_layers=args.eprenet_conv_layers,
        #             econv_chans=args.eprenet_conv_chans,
        #             econv_filts=args.eprenet_conv_filts,
        #             use_batch_norm=args.use_batch_norm,
        #             dropout_rate=args.eprenet_dropout_rate,
        #             padding_idx=padding_idx,
        #         ),
        #         torch.nn.Linear(args.eprenet_conv_chans, args.adim),
        #     )
        # else:
        #     encoder_input_layer = torch.nn.Embedding(
        #         num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx
        #     )

        # define projection layer
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define decoder prenet (also reused below as the encoder input layer)
        if args.dprenet_layers != 0:
            # decoder prenet
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(
                    idim=odim,
                    n_layers=args.dprenet_layers,
                    n_units=args.dprenet_units,
                    dropout_rate=args.dprenet_dropout_rate,
                ),
                torch.nn.Linear(args.dprenet_units, args.adim),
            )
        else:
            decoder_input_layer = "linear"
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=decoder_input_layer,  # see NOTE above: encoder prenet disabled
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )
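
        # define transformer decoder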
        self.decoder = Decoder(
            odim=-1,  # unused here: use_output_layer=False (feat_out below projects to odim)
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=args.transformer_enc_dec_attn_dropout_rate,
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
        )

        # define final projection
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)
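        # feat_out emits ``reduction_factor`` acoustic frames per decoder step
        # (reshaped back to frame level in the forward pass); prob_out emits
        # the matching stop-token logits.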

        # define postnet
        self.postnet = (None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate,
        ))

        # define loss function
        self.criterion = TransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )

        # initialize parameters
        self._reset_parameters(
            init_type=args.transformer_init,
            init_enc_alpha=args.initial_encoder_alpha,
            init_dec_alpha=args.initial_decoder_alpha,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)
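
A minimal usage sketch for constructors of this shape. TTSTransformer is a hypothetical stand-in for whichever class the __init__ above belongs to; because fill_missing_args populates every hyperparameter from the class's own add_arguments defaults, an empty Namespace is enough for a default configuration.

    import argparse

    args = argparse.Namespace()  # hyperparameters filled with defaults in __init__
    model = TTSTransformer(idim=70, odim=80, args=args)  # e.g. 70 symbols -> 80-dim mel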
Example #9
0
    def __init__(self, idim, odim, args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The filter size of encoder conv layers.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The filter size of postnet conv layers.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (str): The name of the activation function for outputs.
                - adim (int): The number of dimensions of MLP in attention.
                - aconv_chans (int): The number of attention conv filter channels.
                - aconv_filts (int): The filter size of attention conv layers.
                - cumulate_att_w (bool): Whether to cumulate previous attention weight.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spkidloss_weight (float): Weight of the speaker-id loss when combined with the TTS loss.
                - num_spk (int): Number of speakers in the training data.
                - spc_dim (int): Number of spectrogram embedding dimensions (only for use_cbhg=True).
                - use_cbhg (bool): Whether to use CBHG module.
                - cbhg_conv_bank_layers (int): The number of convolutional banks in CBHG.
                - cbhg_conv_bank_chans (int): The number of channels of convolutional bank in CBHG.
                - cbhg_conv_proj_filts (int): The filter size of projection layers in CBHG.
                - cbhg_conv_proj_chans (int): The number of channels of projection layers in CBHG.
                - cbhg_highway_layers (int): The number of layers of highway network in CBHG.
                - cbhg_highway_units (int): The number of units of highway network in CBHG.
                - cbhg_gru_units (int): The number of units of GRU in CBHG.
                - use_masking (bool): Whether to mask padded part in loss calculation.
                - bce_pos_weight (float): Weight of positive sample of stop token (only for use_masking=True).
                - use_guided_attn_loss (bool): Whether to use guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention loss.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        self.spkidloss_weight = args.spkidloss_weight
        self.cumulate_att_w = args.cumulate_att_w
        self.reduction_factor = args.reduction_factor
        self.use_cbhg = args.use_cbhg
        self.use_guided_attn_loss = args.use_guided_attn_loss

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError('there is no such activation function (%s)' %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(idim=idim,
                           embed_dim=args.embed_dim,
                           elayers=args.elayers,
                           eunits=args.eunits,
                           econv_layers=args.econv_layers,
                           econv_chans=args.econv_chans,
                           econv_filts=args.econv_filts,
                           use_batch_norm=args.use_batch_norm,
                           use_residual=args.use_residual,
                           dropout_rate=args.dropout_rate,
                           padding_idx=padding_idx)
        if args.train_spkid_extractor:
            self.train_spkid_extractor = True
            self.resnet_spkid = E2E_speakerid(input_dim=odim,
                                              output_dim=args.num_spk,
                                              Q=odim - 1,
                                              D=32,
                                              hidden_dim=args.spk_embed_dim,
                                              pooling='mean',
                                              network_type='lde',
                                              distance_type='sqr',
                                              asoftmax=True,
                                              resnet_AvgPool2d_fre_ksize=10)
            self.angle_loss = AngleLoss()
        else:
            self.train_spkid_extractor = False
        # decoder input dim grows when a speaker embedding is concatenated
        dec_idim = (args.eunits if args.spk_embed_dim is None
                    else args.eunits + args.spk_embed_dim)
        if args.atype == "location":
            att = AttLoc(dec_idim, args.dunits, args.adim, args.aconv_chans,
                         args.aconv_filts)
        elif args.atype == "forward":
            att = AttForward(dec_idim, args.dunits, args.adim,
                             args.aconv_chans, args.aconv_filts)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "forward_ta":
            att = AttForwardTA(dec_idim, args.dunits, args.adim,
                               args.aconv_chans, args.aconv_filts, odim)
            if self.cumulate_att_w:
                logging.warning(
                    "cumulation of attention weights is disabled in forward attention."
                )
                self.cumulate_att_w = False
        elif args.atype == "noatt":  # This condition is satisfied only when using phone alignment for TTS input (Currently when using phn. ali. in TTS training for learning speaker embedding)
            att = None
        else:
            raise NotImplementedError("Support only location or forward")
        self.dec = Decoder(idim=dec_idim,
                           odim=odim,
                           att=att,
                           dlayers=args.dlayers,
                           dunits=args.dunits,
                           prenet_layers=args.prenet_layers,
                           prenet_units=args.prenet_units,
                           postnet_layers=args.postnet_layers,
                           postnet_chans=args.postnet_chans,
                           postnet_filts=args.postnet_filts,
                           output_activation_fn=self.output_activation_fn,
                           cumulate_att_w=self.cumulate_att_w,
                           use_batch_norm=args.use_batch_norm,
                           use_concate=args.use_concate,
                           dropout_rate=args.dropout_rate,
                           zoneout_rate=args.zoneout_rate,
                           reduction_factor=args.reduction_factor)
        self.taco2_loss = Tacotron2Loss(use_masking=args.use_masking,
                                        bce_pos_weight=args.bce_pos_weight)
        if self.use_guided_attn_loss:
            self.attn_loss = GuidedAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )
        if self.use_cbhg:
            self.cbhg = CBHG(idim=odim,
                             odim=args.spc_dim,
                             conv_bank_layers=args.cbhg_conv_bank_layers,
                             conv_bank_chans=args.cbhg_conv_bank_chans,
                             conv_proj_filts=args.cbhg_conv_proj_filts,
                             conv_proj_chans=args.cbhg_conv_proj_chans,
                             highway_layers=args.cbhg_highway_layers,
                             highway_units=args.cbhg_highway_units,
                             gru_units=args.cbhg_gru_units)
            self.cbhg_loss = CBHGLoss(use_masking=args.use_masking)
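            # (CBHG refines the mel prediction into a linear spectrogram of
            # spc_dim bins; its loss is added on top of the Tacotron 2 loss.)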