Esempio n. 1
0
    def __init__(self, idim, odim, args=None):
        """Build the TTS-Transformer network and its loss from hyperparameters.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Hyperparameters. It is first passed
                through ``fill_missing_args`` together with
                ``self.add_arguments``. The options read here are:

                Encoder prenet / embedding
                - embed_dim (int): Dimension of character embedding.
                - eprenet_conv_layers (int): Number of encoder prenet
                  convolution layers (0 selects a plain embedding instead).
                - eprenet_conv_chans (int): Number of encoder prenet
                  convolution channels.
                - eprenet_conv_filts (int): Filter size of encoder prenet
                  convolution.
                - eprenet_dropout_rate (float): Dropout rate in encoder prenet.
                - use_batch_norm (bool): Whether to use batch normalization
                  in encoder prenet (also forwarded to the postnet).

                Decoder prenet
                - dprenet_layers (int): Number of decoder prenet layers
                  (0 selects the ``"linear"`` input layer).
                - dprenet_units (int): Number of decoder prenet hidden units.
                - dprenet_dropout_rate (float): Dropout rate in decoder prenet.

                Transformer body
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - use_scaled_pos_enc (bool): Whether to use trainable scaled
                  positional encoding.
                - encoder_normalize_before (bool): Whether to perform layer
                  normalization before encoder block.
                - decoder_normalize_before (bool): Whether to perform layer
                  normalization before decoder block.
                - encoder_concat_after (bool): Whether to concatenate
                  attention layer's input and output in encoder.
                - decoder_concat_after (bool): Whether to concatenate
                  attention layer's input and output in decoder.
                - positionwise_layer_type: Forwarded to the encoder under the
                  same keyword.
                - positionwise_conv_kernel_size: Forwarded to the encoder
                  under the same keyword.
                - reduction_factor (int): Reduction factor.

                Postnet
                - postnet_layers (int): Number of postnet layers (0 disables
                  the postnet).
                - postnet_chans (int): Number of postnet channels.
                - postnet_filts (int): Filter size of postnet.
                - postnet_dropout_rate (float): Dropout rate in postnet.

                Speaker embedding
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker
                  embedding.

                Dropout
                - transformer_enc_dropout_rate (float): Dropout rate in
                  encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float): Dropout
                  rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float): Dropout rate in
                  encoder self-attention module.
                - transformer_dec_dropout_rate (float): Dropout rate in
                  decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float): Dropout
                  rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float): Dropout rate in
                  decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float): Dropout rate
                  in encoder-decoder attention module.

                Initialization / optimization
                - transformer_init (float): How to initialize transformer
                  parameters.
                - initial_encoder_alpha: Passed to ``_reset_parameters`` as
                  ``init_enc_alpha``.
                - initial_decoder_alpha: Passed to ``_reset_parameters`` as
                  ``init_dec_alpha``.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - pretrained_model: When not None, handed to
                  ``load_pretrained_model``.

                Loss
                - use_masking (bool): Whether to apply masking for padded
                  part in loss calculation.
                - use_weighted_masking (bool): Whether to apply weighted
                  masking in loss calculation.
                - bce_pos_weight (float): Positive sample weight in bce
                  calculation (only for use_masking=true).
                - loss_type (str): How to calculate loss.
                - use_guided_attn_loss (bool): Whether to use guided
                  attention loss.
                - num_heads_applied_guided_attn (int): Number of heads in
                  each layer to apply guided attention loss (-1 means
                  ``aheads``).
                - num_layers_applied_guided_attn (int): Number of layers to
                  apply guided attention loss (-1 means ``elayers``).
                - modules_applied_guided_attn (list): List of module names to
                  apply guided attention loss.
                - guided_attn_loss_sigma (float): Sigma in guided attention
                  loss.
                - guided_attn_loss_lambda (float): Lambda in guided attention
                  loss.

        """
        # both base classes are initialized explicitly
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # complete the option namespace before anything reads from it
        args = fill_missing_args(args, self.add_arguments)

        # hyperparameters that are kept on the instance
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            # only defined when a speaker embedding is configured
            self.spk_embed_integration_type = args.spk_embed_integration_type
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            # -1 is the "use the model-wide value" marker for both counts
            layers_opt = args.num_layers_applied_guided_attn
            self.num_layers_applied_guided_attn = (
                args.elayers if layers_opt == -1 else layers_opt)
            heads_opt = args.num_heads_applied_guided_attn
            self.num_heads_applied_guided_attn = (
                args.aheads if heads_opt == -1 else heads_opt)
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # index 0 is reserved for padding
        padding_idx = 0

        # positional encoding class shared by encoder and decoder
        if self.use_scaled_pos_enc:
            pos_enc_class = ScaledPositionalEncoding
        else:
            pos_enc_class = PositionalEncoding

        # transformer encoder: input layer first, then the encoder itself
        if args.eprenet_conv_layers == 0:
            # no conv prenet requested: embed straight into attention space
            encoder_input_layer = torch.nn.Embedding(
                num_embeddings=idim,
                embedding_dim=args.adim,
                padding_idx=padding_idx,
            )
        else:
            # conv prenet followed by a projection to the attention dimension
            encoder_prenet = EncoderPrenet(
                idim=idim,
                embed_dim=args.embed_dim,
                elayers=0,
                econv_layers=args.eprenet_conv_layers,
                econv_chans=args.eprenet_conv_chans,
                econv_filts=args.eprenet_conv_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.eprenet_dropout_rate,
                padding_idx=padding_idx,
            )
            encoder_input_layer = torch.nn.Sequential(
                encoder_prenet,
                torch.nn.Linear(args.eprenet_conv_chans, args.adim),
            )
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=(
                args.transformer_enc_positional_dropout_rate),
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # speaker-embedding projection; input width depends on integration
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                projection_in = self.spk_embed_dim
            else:
                projection_in = args.adim + self.spk_embed_dim
            self.projection = torch.nn.Linear(projection_in, args.adim)

        # transformer decoder: input layer first, then the decoder itself
        if args.dprenet_layers == 0:
            decoder_input_layer = "linear"
        else:
            # prenet followed by a projection to the attention dimension
            decoder_prenet = DecoderPrenet(
                idim=odim,
                n_layers=args.dprenet_layers,
                n_units=args.dprenet_units,
                dropout_rate=args.dprenet_dropout_rate,
            )
            decoder_input_layer = torch.nn.Sequential(
                decoder_prenet,
                torch.nn.Linear(args.dprenet_units, args.adim),
            )
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=(
                args.transformer_dec_positional_dropout_rate),
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=(
                args.transformer_enc_dec_attn_dropout_rate),
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
        )

        # output heads: each is sized by the reduction factor
        self.feat_out = torch.nn.Linear(
            args.adim, odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # optional postnet
        if args.postnet_layers == 0:
            self.postnet = None
        else:
            self.postnet = Postnet(
                idim=idim,
                odim=odim,
                n_layers=args.postnet_layers,
                n_chans=args.postnet_chans,
                n_filts=args.postnet_filts,
                use_batch_norm=args.use_batch_norm,
                dropout_rate=args.postnet_dropout_rate,
            )

        # loss functions
        self.criterion = TransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )

        # parameter initialization runs once every sub-module exists
        self._reset_parameters(
            init_type=args.transformer_init,
            init_enc_alpha=args.initial_encoder_alpha,
            init_dec_alpha=args.initial_decoder_alpha,
        )

        # optionally load a pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)
    def __init__(self, idim, odim, args=None):
        """Initialize the Transformer encoder/decoder module and its loss.

        Unlike a text front end, this constructor builds no embedding or
        encoder prenet: the encoder is given the same input layer that is
        built for the decoder (a ``DecoderPrenet`` followed by a linear
        projection, or the string ``"linear"`` when ``dprenet_layers`` is 0).

        Args:
            idim (int): Dimension of the inputs; forwarded to the encoder and
                the postnet as ``idim``.
            odim (int): Dimension of the outputs; sizes the decoder prenet
                input, the feature output layer and the postnet.
            args (Namespace, optional): Hyperparameters. It is first passed
                through ``fill_missing_args`` together with
                ``self.add_arguments``. The options read here are:
                - spk_embed_dim, spk_embed_integration_type: Speaker
                  embedding size and integration mode (``"add"`` projects
                  the embedding alone; any other value projects
                  ``adim + spk_embed_dim`` inputs).
                - use_scaled_pos_enc (bool): Selects
                  ``ScaledPositionalEncoding`` over ``PositionalEncoding``.
                - reduction_factor (int): Multiplier for the widths of the
                  two output layers.
                - loss_type: Stored on the instance.
                - use_guided_attn_loss (bool),
                  num_layers_applied_guided_attn,
                  num_heads_applied_guided_attn,
                  modules_applied_guided_attn, guided_attn_loss_sigma,
                  guided_attn_loss_lambda: Guided attention loss settings;
                  -1 for the layer/head counts falls back to ``elayers`` /
                  ``aheads``.
                - dprenet_layers, dprenet_units, dprenet_dropout_rate:
                  Prenet configuration (0 layers selects ``"linear"``).
                - adim, aheads, elayers, eunits, dlayers, dunits: Sizes
                  forwarded to the encoder and decoder.
                - transformer_enc_dropout_rate,
                  transformer_enc_positional_dropout_rate,
                  transformer_enc_attn_dropout_rate,
                  transformer_dec_dropout_rate,
                  transformer_dec_positional_dropout_rate,
                  transformer_dec_attn_dropout_rate,
                  transformer_enc_dec_attn_dropout_rate: Dropout rates
                  forwarded to the encoder and decoder.
                - encoder_normalize_before, decoder_normalize_before,
                  encoder_concat_after, decoder_concat_after,
                  positionwise_layer_type, positionwise_conv_kernel_size:
                  Forwarded unchanged to the encoder / decoder.
                - postnet_layers, postnet_chans, postnet_filts,
                  postnet_dropout_rate, use_batch_norm: Postnet
                  configuration (0 layers disables the postnet).
                - use_masking, use_weighted_masking, bce_pos_weight:
                  Forwarded to ``TransformerLoss``.
                - transformer_init, initial_encoder_alpha,
                  initial_decoder_alpha: Forwarded to
                  ``_reset_parameters``.
                - pretrained_model: When not None, handed to
                  ``load_pretrained_model``.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            # only defined when a speaker embedding is configured
            self.spk_embed_integration_type = args.spk_embed_integration_type
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.reduction_factor = args.reduction_factor
        self.loss_type = args.loss_type
        self.use_guided_attn_loss = args.use_guided_attn_loss
        if self.use_guided_attn_loss:
            # -1 means "use the model-wide layer / head count"
            if args.num_layers_applied_guided_attn == -1:
                self.num_layers_applied_guided_attn = args.elayers
            else:
                self.num_layers_applied_guided_attn = (
                    args.num_layers_applied_guided_attn)
            if args.num_heads_applied_guided_attn == -1:
                self.num_heads_applied_guided_attn = args.aheads
            else:
                self.num_heads_applied_guided_attn = (
                    args.num_heads_applied_guided_attn)
            self.modules_applied_guided_attn = args.modules_applied_guided_attn

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define projection layer
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define the input layer used by BOTH the encoder and the decoder
        if args.dprenet_layers != 0:
            # prenet followed by a projection to the attention dimension
            decoder_input_layer = torch.nn.Sequential(
                DecoderPrenet(
                    idim=odim,
                    n_layers=args.dprenet_layers,
                    n_units=args.dprenet_units,
                    dropout_rate=args.dprenet_dropout_rate,
                ),
                torch.nn.Linear(args.dprenet_units, args.adim),
            )
        else:
            decoder_input_layer = "linear"

        # define transformer encoder
        # NOTE(review): the encoder is handed the very same
        # `decoder_input_layer` object as the decoder below. When it is a
        # Sequential this presumably ties the prenet weights of both sides
        # -- confirm that this sharing is intended.
        # NOTE(review): the prenet is built with idim=odim while the encoder
        # is told idim=idim; this looks like it assumes idim == odim --
        # confirm against callers.
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=decoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=(
                args.transformer_enc_positional_dropout_rate),
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size,
        )

        # define transformer decoder
        self.decoder = Decoder(
            odim=-1,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=(
                args.transformer_dec_positional_dropout_rate),
            self_attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            src_attention_dropout_rate=(
                args.transformer_enc_dec_attn_dropout_rate),
            input_layer=decoder_input_layer,
            use_output_layer=False,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
        )

        # define final projection; both heads are sized by the reduction factor
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)
        self.prob_out = torch.nn.Linear(args.adim, args.reduction_factor)

        # define postnet (disabled when postnet_layers == 0)
        self.postnet = (None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate,
        ))

        # define loss function
        self.criterion = TransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
            bce_pos_weight=args.bce_pos_weight,
        )
        if self.use_guided_attn_loss:
            self.attn_criterion = GuidedMultiHeadAttentionLoss(
                sigma=args.guided_attn_loss_sigma,
                alpha=args.guided_attn_loss_lambda,
            )

        # initialize parameters (runs once every sub-module exists)
        self._reset_parameters(
            init_type=args.transformer_init,
            init_enc_alpha=args.initial_encoder_alpha,
            init_dec_alpha=args.initial_decoder_alpha,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)