def __init__(
    self,
    idim: int,
    odim: int,
    adim: int = 256,
    aheads: int = 2,
    elayers: int = 4,
    eunits: int = 1024,
    dlayers: int = 4,
    dunits: int = 1024,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    reduction_factor: int = 1,
    encoder_type: str = "transformer",
    decoder_type: str = "transformer",
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    # only for conformer
    conformer_rel_pos_type: str = "legacy",
    conformer_pos_enc_layer_type: str = "rel_pos",
    conformer_self_attn_layer_type: str = "rel_selfattn",
    conformer_activation_type: str = "swish",
    use_macaron_style_in_conformer: bool = True,
    use_cnn_in_conformer: bool = True,
    zero_triu: bool = False,
    conformer_enc_kernel_size: int = 7,
    conformer_dec_kernel_size: int = 31,
    # duration predictor
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    duration_predictor_dropout_rate: float = 0.1,
    # energy predictor
    energy_predictor_layers: int = 2,
    energy_predictor_chans: int = 384,
    energy_predictor_kernel_size: int = 3,
    energy_predictor_dropout: float = 0.5,
    energy_embed_kernel_size: int = 9,
    energy_embed_dropout: float = 0.5,
    stop_gradient_from_energy_predictor: bool = False,
    # pitch predictor
    pitch_predictor_layers: int = 2,
    pitch_predictor_chans: int = 384,
    pitch_predictor_kernel_size: int = 3,
    pitch_predictor_dropout: float = 0.5,
    pitch_embed_kernel_size: int = 9,
    pitch_embed_dropout: float = 0.5,
    stop_gradient_from_pitch_predictor: bool = False,
    # extra embedding related
    spks: Optional[int] = None,
    langs: Optional[int] = None,
    spk_embed_dim: Optional[int] = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
    segment_size: int = 64,
    # hifigan generator
    generator_out_channels: int = 1,
    generator_channels: int = 512,
    generator_global_channels: int = -1,
    generator_kernel_size: int = 7,
    # NOTE: container defaults use None sentinels (resolved below) to avoid
    # the shared-mutable-default-argument pitfall; effective defaults are
    # unchanged, so this is backward-compatible for all callers.
    generator_upsample_scales: Optional[List[int]] = None,
    generator_upsample_kernel_sizes: Optional[List[int]] = None,
    generator_resblock_kernel_sizes: Optional[List[int]] = None,
    generator_resblock_dilations: Optional[List[List[int]]] = None,
    generator_use_additional_convs: bool = True,
    generator_bias: bool = True,
    generator_nonlinear_activation: str = "LeakyReLU",
    generator_nonlinear_activation_params: Optional[Dict[str, Any]] = None,
    generator_use_weight_norm: bool = True,
):
    """Initialize JETS generator module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        adim (int): Attention dimension (hidden size of encoder/decoder).
        aheads (int): Number of attention heads.
        elayers (int): Number of encoder layers.
        eunits (int): Number of encoder hidden units.
        dlayers (int): Number of decoder layers.
        dunits (int): Number of decoder hidden units.
        use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding.
        use_batch_norm (bool): Whether to use batch normalization in encoder
            prenet.
        encoder_normalize_before (bool): Whether to apply layernorm layer before
            encoder block.
        decoder_normalize_before (bool): Whether to apply layernorm layer before
            decoder block.
        encoder_concat_after (bool): Whether to concatenate attention layer's
            input and output in encoder.
        decoder_concat_after (bool): Whether to concatenate attention layer's
            input and output in decoder.
        reduction_factor (int): Reduction factor.
        encoder_type (str): Encoder type ("transformer" or "conformer").
        decoder_type (str): Decoder type ("transformer" or "conformer").
        transformer_enc_dropout_rate (float): Dropout rate in encoder except
            attention and positional encoding.
        transformer_enc_positional_dropout_rate (float): Dropout rate after
            encoder positional encoding.
        transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
            self-attention module.
        transformer_dec_dropout_rate (float): Dropout rate in decoder except
            attention & positional encoding.
        transformer_dec_positional_dropout_rate (float): Dropout rate after
            decoder positional encoding.
        transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
            self-attention module.
        conformer_rel_pos_type (str): Relative pos encoding type in conformer.
        conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer.
        conformer_self_attn_layer_type (str): Self-attention layer type in
            conformer.
        conformer_activation_type (str): Activation function type in conformer.
        use_macaron_style_in_conformer (bool): Whether to use macaron style FFN.
        use_cnn_in_conformer (bool): Whether to use CNN in conformer.
        zero_triu (bool): Whether to use zero triu in relative self-attention
            module.
        conformer_enc_kernel_size (int): Kernel size of encoder conformer.
        conformer_dec_kernel_size (int): Kernel size of decoder conformer.
        duration_predictor_layers (int): Number of duration predictor layers.
        duration_predictor_chans (int): Number of duration predictor channels.
        duration_predictor_kernel_size (int): Kernel size of duration predictor.
        duration_predictor_dropout_rate (float): Dropout rate in duration
            predictor.
        pitch_predictor_layers (int): Number of pitch predictor layers.
        pitch_predictor_chans (int): Number of pitch predictor channels.
        pitch_predictor_kernel_size (int): Kernel size of pitch predictor.
        pitch_predictor_dropout (float): Dropout rate in pitch predictor.
        pitch_embed_kernel_size (int): Kernel size of pitch embedding.
        pitch_embed_dropout (float): Dropout rate for pitch embedding.
        stop_gradient_from_pitch_predictor (bool): Whether to stop gradient
            from pitch predictor to encoder.
        energy_predictor_layers (int): Number of energy predictor layers.
        energy_predictor_chans (int): Number of energy predictor channels.
        energy_predictor_kernel_size (int): Kernel size of energy predictor.
        energy_predictor_dropout (float): Dropout rate in energy predictor.
        energy_embed_kernel_size (int): Kernel size of energy embedding.
        energy_embed_dropout (float): Dropout rate for energy embedding.
        stop_gradient_from_energy_predictor (bool): Whether to stop gradient
            from energy predictor to encoder.
        spks (Optional[int]): Number of speakers. If set to > 1, assume that the
            sids will be provided as the input and use sid embedding layer.
        langs (Optional[int]): Number of languages. If set to > 1, assume that
            the lids will be provided as the input and use lid embedding layer.
        spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to
            > 0, assume that spembs will be provided as the input.
        spk_embed_integration_type (str): How to integrate speaker embedding.
        use_gst (bool): Whether to use global style token.
        gst_tokens (int): The number of GST embeddings.
        gst_heads (int): The number of heads in GST multihead attention.
        gst_conv_layers (int): The number of conv layers in GST.
        gst_conv_chans_list (Sequence[int]): List of the number of channels of
            conv layers in GST.
        gst_conv_kernel_size (int): Kernel size of conv layers in GST.
        gst_conv_stride (int): Stride size of conv layers in GST.
        gst_gru_layers (int): The number of GRU layers in GST.
        gst_gru_units (int): The number of GRU units in GST.
        init_type (str): How to initialize transformer parameters.
        init_enc_alpha (float): Initial value of alpha in scaled pos encoding of
            the encoder.
        init_dec_alpha (float): Initial value of alpha in scaled pos encoding of
            the decoder.
        use_masking (bool): Whether to apply masking for padded part in loss
            calculation.
        use_weighted_masking (bool): Whether to apply weighted masking in loss
            calculation.
        segment_size (int): Segment size for random windowed discriminator.
        generator_out_channels (int): Number of output channels.
        generator_channels (int): Number of hidden representation channels.
        generator_global_channels (int): Number of global conditioning channels.
        generator_kernel_size (int): Kernel size of initial and final conv
            layer.
        generator_upsample_scales (Optional[List[int]]): List of upsampling
            scales. Defaults to [8, 8, 2, 2] when None.
        generator_upsample_kernel_sizes (Optional[List[int]]): List of kernel
            sizes for upsample layers. Defaults to [16, 16, 4, 4] when None.
        generator_resblock_kernel_sizes (Optional[List[int]]): List of kernel
            sizes for residual blocks. Defaults to [3, 7, 11] when None.
        generator_resblock_dilations (Optional[List[List[int]]]): List of list
            of dilations for residual blocks. Defaults to
            [[1, 3, 5], [1, 3, 5], [1, 3, 5]] when None.
        generator_use_additional_convs (bool): Whether to use additional conv
            layers in residual blocks.
        generator_bias (bool): Whether to add bias parameter in convolution
            layers.
        generator_nonlinear_activation (str): Activation function module name.
        generator_nonlinear_activation_params (Optional[Dict[str, Any]]):
            Hyperparameters for activation function. Defaults to
            {"negative_slope": 0.1} when None.
        generator_use_weight_norm (bool): Whether to use weight norm. If set to
            true, it will be applied to all of the conv layers.

    """
    super().__init__()

    # resolve None sentinels to the documented defaults; doing this here
    # (instead of mutable defaults in the signature) guarantees every
    # instance gets its own fresh container
    if generator_upsample_scales is None:
        generator_upsample_scales = [8, 8, 2, 2]
    if generator_upsample_kernel_sizes is None:
        generator_upsample_kernel_sizes = [16, 16, 4, 4]
    if generator_resblock_kernel_sizes is None:
        generator_resblock_kernel_sizes = [3, 7, 11]
    if generator_resblock_dilations is None:
        generator_resblock_dilations = [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
    if generator_nonlinear_activation_params is None:
        generator_nonlinear_activation_params = {"negative_slope": 0.1}

    self.segment_size = segment_size
    # total hop size of the HiFiGAN upsampling stack (product of scales)
    self.upsample_factor = int(np.prod(generator_upsample_scales))
    self.idim = idim
    self.odim = odim
    self.reduction_factor = reduction_factor
    self.encoder_type = encoder_type
    self.decoder_type = decoder_type
    self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
    self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    # check relative positional encoding compatibility
    if "conformer" in [encoder_type, decoder_type]:
        if conformer_rel_pos_type == "legacy":
            if conformer_pos_enc_layer_type == "rel_pos":
                conformer_pos_enc_layer_type = "legacy_rel_pos"
                logging.warning(
                    "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
                logging.warning(
                    "Fallback to "
                    "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
        else:
            raise ValueError(
                f"Unknown rel_pos_type: {conformer_rel_pos_type}")

    # define encoder
    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=adim,
                                             padding_idx=self.padding_idx)
    if encoder_type == "transformer":
        self.encoder = TransformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif encoder_type == "conformer":
        self.encoder = ConformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
            zero_triu=zero_triu,
        )
    else:
        raise ValueError(f"{encoder_type} is not supported.")

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define spk and lang embedding
    self.spks = None
    if spks is not None and spks > 1:
        self.spks = spks
        self.sid_emb = torch.nn.Embedding(spks, adim)
    self.langs = None
    if langs is not None and langs > 1:
        self.langs = langs
        self.lid_emb = torch.nn.Embedding(langs, adim)

    # define additional projection for speaker embedding
    self.spk_embed_dim = None
    if spk_embed_dim is not None and spk_embed_dim > 0:
        self.spk_embed_dim = spk_embed_dim
        self.spk_embed_integration_type = spk_embed_integration_type
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # define pitch predictor
    self.pitch_predictor = VariancePredictor(
        idim=adim,
        n_layers=pitch_predictor_layers,
        n_chans=pitch_predictor_chans,
        kernel_size=pitch_predictor_kernel_size,
        dropout_rate=pitch_predictor_dropout,
    )
    # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
    self.pitch_embed = torch.nn.Sequential(
        torch.nn.Conv1d(
            in_channels=1,
            out_channels=adim,
            kernel_size=pitch_embed_kernel_size,
            padding=(pitch_embed_kernel_size - 1) // 2,
        ),
        torch.nn.Dropout(pitch_embed_dropout),
    )

    # define energy predictor
    self.energy_predictor = VariancePredictor(
        idim=adim,
        n_layers=energy_predictor_layers,
        n_chans=energy_predictor_chans,
        kernel_size=energy_predictor_kernel_size,
        dropout_rate=energy_predictor_dropout,
    )
    # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
    self.energy_embed = torch.nn.Sequential(
        torch.nn.Conv1d(
            in_channels=1,
            out_channels=adim,
            kernel_size=energy_embed_kernel_size,
            padding=(energy_embed_kernel_size - 1) // 2,
        ),
        torch.nn.Dropout(energy_embed_dropout),
    )

    # define AlignmentModule
    self.alignment_module = AlignmentModule(adim, odim)

    # define length regulator
    self.length_regulator = GaussianUpsampling()

    # define decoder
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    if decoder_type == "transformer":
        self.decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif decoder_type == "conformer":
        self.decoder = ConformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_dec_kernel_size,
        )
    else:
        raise ValueError(f"{decoder_type} is not supported.")

    # define hifigan generator
    self.generator = HiFiGANGenerator(
        in_channels=adim,
        out_channels=generator_out_channels,
        channels=generator_channels,
        global_channels=generator_global_channels,
        kernel_size=generator_kernel_size,
        upsample_scales=generator_upsample_scales,
        upsample_kernel_sizes=generator_upsample_kernel_sizes,
        resblock_kernel_sizes=generator_resblock_kernel_sizes,
        resblock_dilations=generator_resblock_dilations,
        use_additional_convs=generator_use_additional_convs,
        bias=generator_bias,
        nonlinear_activation=generator_nonlinear_activation,
        nonlinear_activation_params=generator_nonlinear_activation_params,
        use_weight_norm=generator_use_weight_norm,
    )

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )
def __init__(self, idim, odim, args=None):
    """Initialize feed-forward Transformer (FastSpeech) module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (Namespace, optional): Hyperparameter namespace; entries
            missing from it are filled with the defaults declared by
            ``self.add_arguments`` via ``fill_missing_args``.

    Raises:
        NotImplementedError: If ``args.reduction_factor`` is not 1
            (only reduction_factor = 1 is supported).
    """
    # initialize base classes
    TTSInterface.__init__(self)
    torch.nn.Module.__init__(self)

    # fill missing arguments
    args = fill_missing_args(args, self.add_arguments)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.reduction_factor = args.reduction_factor
    self.use_scaled_pos_enc = args.use_scaled_pos_enc
    self.use_masking = args.use_masking
    self.spk_embed_dim = args.spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = args.spk_embed_integration_type

    # TODO(kan-bayashi): support reduction_factor > 1
    if self.reduction_factor != 1:
        raise NotImplementedError("Support only reduction_factor = 1.")

    # use idx 0 as padding idx
    padding_idx = 0

    # get positional encoding class
    pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

    # define encoder
    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=args.adim,
                                             padding_idx=padding_idx)
    self.encoder = Encoder(
        idim=idim,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.eunits,
        num_blocks=args.elayers,
        input_layer=encoder_input_layer,
        dropout_rate=args.transformer_enc_dropout_rate,
        positional_dropout_rate=args.transformer_enc_positional_dropout_rate,
        attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.encoder_normalize_before,
        concat_after=args.encoder_concat_after,
        positionwise_layer_type=args.positionwise_layer_type,
        positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

    # define additional projection for speaker embedding
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim)
        else:
            # any non-"add" integration type concatenates the speaker
            # embedding to the hidden states before projecting back to adim
            self.projection = torch.nn.Linear(
                args.adim + self.spk_embed_dim, args.adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=args.adim,
        n_layers=args.duration_predictor_layers,
        n_chans=args.duration_predictor_chans,
        kernel_size=args.duration_predictor_kernel_size,
        dropout_rate=args.duration_predictor_dropout_rate,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
    self.decoder = Encoder(
        idim=0,
        attention_dim=args.adim,
        attention_heads=args.aheads,
        linear_units=args.dunits,
        num_blocks=args.dlayers,
        input_layer=None,
        dropout_rate=args.transformer_dec_dropout_rate,
        positional_dropout_rate=args.transformer_dec_positional_dropout_rate,
        attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=args.decoder_normalize_before,
        concat_after=args.decoder_concat_after,
        positionwise_layer_type=args.positionwise_layer_type,
        positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

    # define final projection
    self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor)

    # initialize parameters
    self._reset_parameters(init_type=args.transformer_init,
                           init_enc_alpha=args.initial_encoder_alpha,
                           init_dec_alpha=args.initial_decoder_alpha)

    # define teacher model (used to extract target durations for training)
    if args.teacher_model is not None:
        self.teacher = self._load_teacher_model(args.teacher_model)
    else:
        self.teacher = None

    # define duration calculator
    if self.teacher is not None:
        self.duration_calculator = DurationCalculator(self.teacher)
    else:
        self.duration_calculator = None

    # transfer teacher parameters
    if self.teacher is not None and args.transfer_encoder_from_teacher:
        self._transfer_from_teacher(args.transferred_encoder_module)

    # define criterions
    self.duration_criterion = DurationPredictorLoss()
    # TODO(kan-bayashi): support knowledge distillation loss
    self.criterion = torch.nn.L1Loss()
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    postnet_dropout_rate: float = 0.5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    duration_predictor_dropout_rate: float = 0.1,
    reduction_factor: int = 1,
    encoder_type: str = "transformer",
    decoder_type: str = "transformer",
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    # only for conformer
    conformer_rel_pos_type: str = "legacy",
    conformer_pos_enc_layer_type: str = "rel_pos",
    conformer_self_attn_layer_type: str = "rel_selfattn",
    conformer_activation_type: str = "swish",
    use_macaron_style_in_conformer: bool = True,
    use_cnn_in_conformer: bool = True,
    conformer_enc_kernel_size: int = 7,
    conformer_dec_kernel_size: int = 31,
    zero_triu: bool = False,
    # extra embedding related
    spks: Optional[int] = None,
    langs: Optional[int] = None,
    spk_embed_dim: Optional[int] = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Initialize FastSpeech module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        elayers (int): Number of encoder layers.
        eunits (int): Number of encoder hidden units.
        dlayers (int): Number of decoder layers.
        dunits (int): Number of decoder hidden units.
        postnet_layers (int): Number of postnet layers.
        postnet_chans (int): Number of postnet channels.
        postnet_filts (int): Kernel size of postnet.
        postnet_dropout_rate (float): Dropout rate in postnet.
        use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding.
        use_batch_norm (bool): Whether to use batch normalization in encoder
            prenet.
        encoder_normalize_before (bool): Whether to apply layernorm layer before
            encoder block.
        decoder_normalize_before (bool): Whether to apply layernorm layer before
            decoder block.
        encoder_concat_after (bool): Whether to concatenate attention layer's
            input and output in encoder.
        decoder_concat_after (bool): Whether to concatenate attention layer's
            input and output in decoder.
        duration_predictor_layers (int): Number of duration predictor layers.
        duration_predictor_chans (int): Number of duration predictor channels.
        duration_predictor_kernel_size (int): Kernel size of duration predictor.
        duration_predictor_dropout_rate (float): Dropout rate in duration
            predictor.
        reduction_factor (int): Reduction factor.
        encoder_type (str): Encoder type ("transformer" or "conformer").
        decoder_type (str): Decoder type ("transformer" or "conformer").
        transformer_enc_dropout_rate (float): Dropout rate in encoder except
            attention and positional encoding.
        transformer_enc_positional_dropout_rate (float): Dropout rate after
            encoder positional encoding.
        transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
            self-attention module.
        transformer_dec_dropout_rate (float): Dropout rate in decoder except
            attention & positional encoding.
        transformer_dec_positional_dropout_rate (float): Dropout rate after
            decoder positional encoding.
        transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
            self-attention module.
        conformer_rel_pos_type (str): Relative pos encoding type in conformer.
        conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer.
        conformer_self_attn_layer_type (str): Self-attention layer type in
            conformer.
        conformer_activation_type (str): Activation function type in conformer.
        use_macaron_style_in_conformer (bool): Whether to use macaron style FFN.
        use_cnn_in_conformer (bool): Whether to use CNN in conformer.
        conformer_enc_kernel_size (int): Kernel size of encoder conformer.
        conformer_dec_kernel_size (int): Kernel size of decoder conformer.
        zero_triu (bool): Whether to use zero triu in relative self-attention
            module.
        spks (Optional[int]): Number of speakers. If set to > 1, assume that the
            sids will be provided as the input and use sid embedding layer.
        langs (Optional[int]): Number of languages. If set to > 1, assume that
            the lids will be provided as the input and use lid embedding layer.
        spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to
            > 0, assume that spembs will be provided as the input.
        spk_embed_integration_type (str): How to integrate speaker embedding.
        use_gst (bool): Whether to use global style token.
        gst_tokens (int): The number of GST embeddings.
        gst_heads (int): The number of heads in GST multihead attention.
        gst_conv_layers (int): The number of conv layers in GST.
        gst_conv_chans_list (Sequence[int]): List of the number of channels of
            conv layers in GST.
        gst_conv_kernel_size (int): Kernel size of conv layers in GST.
        gst_conv_stride (int): Stride size of conv layers in GST.
        gst_gru_layers (int): The number of GRU layers in GST.
        gst_gru_units (int): The number of GRU units in GST.
        init_type (str): How to initialize transformer parameters.
        init_enc_alpha (float): Initial value of alpha in scaled pos encoding of
            the encoder.
        init_dec_alpha (float): Initial value of alpha in scaled pos encoding of
            the decoder.
        use_masking (bool): Whether to apply masking for padded part in loss
            calculation.
        use_weighted_masking (bool): Whether to apply weighted masking in loss
            calculation.

    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    # assumes the input vocabulary reserves its last index for <eos>
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.encoder_type = encoder_type
    self.decoder_type = decoder_type
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    # check relative positional encoding compatibility
    if "conformer" in [encoder_type, decoder_type]:
        if conformer_rel_pos_type == "legacy":
            if conformer_pos_enc_layer_type == "rel_pos":
                conformer_pos_enc_layer_type = "legacy_rel_pos"
                logging.warning(
                    "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
                logging.warning(
                    "Fallback to "
                    "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
        else:
            raise ValueError(
                f"Unknown rel_pos_type: {conformer_rel_pos_type}")

    # define encoder
    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=adim,
                                             padding_idx=self.padding_idx)
    if encoder_type == "transformer":
        self.encoder = TransformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif encoder_type == "conformer":
        self.encoder = ConformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
        )
    else:
        raise ValueError(f"{encoder_type} is not supported.")

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define spk and lang embedding
    self.spks = None
    if spks is not None and spks > 1:
        self.spks = spks
        self.sid_emb = torch.nn.Embedding(spks, adim)
    self.langs = None
    if langs is not None and langs > 1:
        self.langs = langs
        self.lid_emb = torch.nn.Embedding(langs, adim)

    # define additional projection for speaker embedding
    self.spk_embed_dim = None
    if spk_embed_dim is not None and spk_embed_dim > 0:
        self.spk_embed_dim = spk_embed_dim
        self.spk_embed_integration_type = spk_embed_integration_type
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            # non-"add" integration concatenates the speaker embedding and
            # projects the result back to adim
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    if decoder_type == "transformer":
        self.decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif decoder_type == "conformer":
        self.decoder = ConformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_dec_kernel_size,
        )
    else:
        raise ValueError(f"{decoder_type} is not supported.")

    # define final projection
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # define postnet (skipped entirely when postnet_layers == 0)
    self.postnet = (None if postnet_layers == 0 else Postnet(
        idim=idim,
        odim=odim,
        n_layers=postnet_layers,
        n_chans=postnet_chans,
        n_filts=postnet_filts,
        use_batch_norm=use_batch_norm,
        dropout_rate=postnet_dropout_rate,
    ))

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # define criterions
    self.criterion = FastSpeechLoss(
        use_masking=use_masking, use_weighted_masking=use_weighted_masking)
class JETSGenerator(torch.nn.Module):
    """Generator module in JETS.

    Text encoder -> (alignment module + variance adaptors) -> decoder ->
    HiFiGAN waveform generator, trained end-to-end. During training,
    ``forward`` learns alignments internally (no external durations) and
    returns random waveform segments for the discriminator; ``inference``
    synthesizes full waveforms from text.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        adim: int = 256,
        aheads: int = 2,
        elayers: int = 4,
        eunits: int = 1024,
        dlayers: int = 4,
        dunits: int = 1024,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        reduction_factor: int = 1,
        encoder_type: str = "transformer",
        decoder_type: str = "transformer",
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        # only for conformer
        conformer_rel_pos_type: str = "legacy",
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        zero_triu: bool = False,
        conformer_enc_kernel_size: int = 7,
        conformer_dec_kernel_size: int = 31,
        # duration predictor
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        duration_predictor_dropout_rate: float = 0.1,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
        # extra embedding related
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
        segment_size: int = 64,
        # hifigan generator
        generator_out_channels: int = 1,
        generator_channels: int = 512,
        generator_global_channels: int = -1,
        generator_kernel_size: int = 7,
        generator_upsample_scales: List[int] = [8, 8, 2, 2],
        generator_upsample_kernel_sizes: List[int] = [16, 16, 4, 4],
        generator_resblock_kernel_sizes: List[int] = [3, 7, 11],
        generator_resblock_dilations: List[List[int]] = [
            [1, 3, 5],
            [1, 3, 5],
            [1, 3, 5],
        ],
        generator_use_additional_convs: bool = True,
        generator_bias: bool = True,
        generator_nonlinear_activation: str = "LeakyReLU",
        generator_nonlinear_activation_params: Dict[str, Any] = {
            "negative_slope": 0.1
        },
        generator_use_weight_norm: bool = True,
    ):
        """Initialize JETS generator module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            adim (int): Attention (hidden) dimension shared by encoder,
                decoder, and all variance modules.
            aheads (int): Number of attention heads.
            elayers (int): Number of encoder layers.
            eunits (int): Number of encoder hidden units.
            dlayers (int): Number of decoder layers.
            dunits (int): Number of decoder hidden units.
            use_scaled_pos_enc (bool): Whether to use trainable scaled pos
                encoding.
            use_batch_norm (bool): Whether to use batch normalization in
                encoder prenet.
            encoder_normalize_before (bool): Whether to apply layernorm layer
                before encoder block.
            decoder_normalize_before (bool): Whether to apply layernorm layer
                before decoder block.
            encoder_concat_after (bool): Whether to concatenate attention
                layer's input and output in encoder.
            decoder_concat_after (bool): Whether to concatenate attention
                layer's input and output in decoder.
            reduction_factor (int): Reduction factor.
            encoder_type (str): Encoder type ("transformer" or "conformer").
            decoder_type (str): Decoder type ("transformer" or "conformer").
            transformer_enc_dropout_rate (float): Dropout rate in encoder
                except attention and positional encoding.
            transformer_enc_positional_dropout_rate (float): Dropout rate
                after encoder positional encoding.
            transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
                self-attention module.
            transformer_dec_dropout_rate (float): Dropout rate in decoder
                except attention & positional encoding.
            transformer_dec_positional_dropout_rate (float): Dropout rate
                after decoder positional encoding.
            transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
                self-attention module.
            conformer_rel_pos_type (str): Relative pos encoding type in
                conformer.
            conformer_pos_enc_layer_type (str): Pos encoding layer type in
                conformer.
            conformer_self_attn_layer_type (str): Self-attention layer type in
                conformer.
            conformer_activation_type (str): Activation function type in
                conformer.
            use_macaron_style_in_conformer (bool): Whether to use macaron
                style FFN.
            use_cnn_in_conformer (bool): Whether to use CNN in conformer.
            zero_triu (bool): Whether to use zero triu in relative
                self-attention module.
            conformer_enc_kernel_size (int): Kernel size of encoder conformer.
            conformer_dec_kernel_size (int): Kernel size of decoder conformer.
            duration_predictor_layers (int): Number of duration predictor
                layers.
            duration_predictor_chans (int): Number of duration predictor
                channels.
            duration_predictor_kernel_size (int): Kernel size of duration
                predictor.
            duration_predictor_dropout_rate (float): Dropout rate in duration
                predictor.
            pitch_predictor_layers (int): Number of pitch predictor layers.
            pitch_predictor_chans (int): Number of pitch predictor channels.
            pitch_predictor_kernel_size (int): Kernel size of pitch predictor.
            pitch_predictor_dropout (float): Dropout rate in pitch predictor.
            pitch_embed_kernel_size (int): Kernel size of pitch embedding.
            pitch_embed_dropout (float): Dropout rate for pitch embedding.
            stop_gradient_from_pitch_predictor (bool): Whether to stop
                gradient from pitch predictor to encoder.
            energy_predictor_layers (int): Number of energy predictor layers.
            energy_predictor_chans (int): Number of energy predictor channels.
            energy_predictor_kernel_size (int): Kernel size of energy
                predictor.
            energy_predictor_dropout (float): Dropout rate in energy
                predictor.
            energy_embed_kernel_size (int): Kernel size of energy embedding.
            energy_embed_dropout (float): Dropout rate for energy embedding.
            stop_gradient_from_energy_predictor (bool): Whether to stop
                gradient from energy predictor to encoder.
            spks (Optional[int]): Number of speakers. If set to > 1, assume
                that the sids will be provided as the input and use sid
                embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume
                that the lids will be provided as the input and use lid
                embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set
                to > 0, assume that spembs will be provided as the input.
            spk_embed_integration_type (str): How to integrate speaker
                embedding ("add" or "concat").
            use_gst (bool): Whether to use global style token.
            gst_tokens (int): The number of GST embeddings.
            gst_heads (int): The number of heads in GST multihead attention.
            gst_conv_layers (int): The number of conv layers in GST.
            gst_conv_chans_list (Sequence[int]): List of the number of
                channels of conv layers in GST.
            gst_conv_kernel_size (int): Kernel size of conv layers in GST.
            gst_conv_stride (int): Stride size of conv layers in GST.
            gst_gru_layers (int): The number of GRU layers in GST.
            gst_gru_units (int): The number of GRU units in GST.
            init_type (str): How to initialize transformer parameters.
            init_enc_alpha (float): Initial value of alpha in scaled pos
                encoding of the encoder.
            init_dec_alpha (float): Initial value of alpha in scaled pos
                encoding of the decoder.
            use_masking (bool): Whether to apply masking for padded part in
                loss calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in
                loss calculation.
            segment_size (int): Segment size for random windowed
                discriminator.
            generator_out_channels (int): Number of output channels.
            generator_channels (int): Number of hidden representation
                channels.
            generator_global_channels (int): Number of global conditioning
                channels.
            generator_kernel_size (int): Kernel size of initial and final conv
                layer.
            generator_upsample_scales (List[int]): List of upsampling scales.
            generator_upsample_kernel_sizes (List[int]): List of kernel sizes
                for upsample layers.
            generator_resblock_kernel_sizes (List[int]): List of kernel sizes
                for residual blocks.
            generator_resblock_dilations (List[List[int]]): List of list of
                dilations for residual blocks.
            generator_use_additional_convs (bool): Whether to use additional
                conv layers in residual blocks.
            generator_bias (bool): Whether to add bias parameter in
                convolution layers.
            generator_nonlinear_activation (str): Activation function module
                name.
            generator_nonlinear_activation_params (Dict[str, Any]):
                Hyperparameters for activation function.
            generator_use_weight_norm (bool): Whether to use weight norm. If
                set to true, it will be applied to all of the conv layers.

        """
        super().__init__()
        self.segment_size = segment_size
        # total waveform samples produced per decoder frame
        self.upsample_factor = int(np.prod(generator_upsample_scales))
        self.idim = idim
        self.odim = odim
        self.reduction_factor = reduction_factor
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # check relative positional encoding compatibility
        # NOTE: legacy configs used "rel_pos"/"rel_selfattn" names for what is
        # now the "legacy_*" implementation; remap so old configs still load.
        if "conformer" in [encoder_type, decoder_type]:
            if conformer_rel_pos_type == "legacy":
                if conformer_pos_enc_layer_type == "rel_pos":
                    conformer_pos_enc_layer_type = "legacy_rel_pos"
                    logging.warning(
                        "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_pos_enc_layer_type = 'latest'.")
                if conformer_self_attn_layer_type == "rel_selfattn":
                    conformer_self_attn_layer_type = "legacy_rel_selfattn"
                    logging.warning(
                        "Fallback to "
                        "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_pos_enc_layer_type = 'latest'.")
            elif conformer_rel_pos_type == "latest":
                assert conformer_pos_enc_layer_type != "legacy_rel_pos"
                assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
            else:
                raise ValueError(
                    f"Unknown rel_pos_type: {conformer_rel_pos_type}")

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        if encoder_type == "transformer":
            self.encoder = TransformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif encoder_type == "conformer":
            self.encoder = ConformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_enc_kernel_size,
                zero_triu=zero_triu,
            )
        else:
            raise ValueError(f"{encoder_type} is not supported.")

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define spk and lang embedding
        self.spks = None
        if spks is not None and spks > 1:
            self.spks = spks
            self.sid_emb = torch.nn.Embedding(spks, adim)
        self.langs = None
        if langs is not None and langs > 1:
            self.langs = langs
            self.lid_emb = torch.nn.Embedding(langs, adim)

        # define additional projection for speaker embedding
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                # "concat" (and any other value): project concatenated
                # [hidden; spemb] back to adim
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )

        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        # define AlignmentModule (learns text-to-feats alignment in training)
        self.alignment_module = AlignmentModule(adim, odim)

        # define length regulator
        self.length_regulator = GaussianUpsampling()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        if decoder_type == "transformer":
            self.decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif decoder_type == "conformer":
            self.decoder = ConformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size,
            )
        else:
            raise ValueError(f"{decoder_type} is not supported.")

        # define hifigan generator
        self.generator = HiFiGANGenerator(
            in_channels=adim,
            out_channels=generator_out_channels,
            channels=generator_channels,
            global_channels=generator_global_channels,
            kernel_size=generator_kernel_size,
            upsample_scales=generator_upsample_scales,
            upsample_kernel_sizes=generator_upsample_kernel_sizes,
            resblock_kernel_sizes=generator_resblock_kernel_sizes,
            resblock_dilations=generator_resblock_dilations,
            use_additional_convs=generator_use_additional_convs,
            bias=generator_bias,
            nonlinear_activation=generator_nonlinear_activation,
            nonlinear_activation_params=generator_nonlinear_activation_params,
            use_weight_norm=generator_use_weight_norm,
        )

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        pitch: torch.Tensor,
        pitch_lengths: torch.Tensor,
        energy: torch.Tensor,
        energy_lengths: torch.Tensor,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
               torch.Tensor, torch.Tensor, ]:
        """Calculate forward propagation.

        Args:
            text (Tensor): Text index tensor (B, T_text).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
            feats_lengths (Tensor): Feature length tensor (B,).
            pitch (Tensor): Batch of padded token-averaged pitch
                (B, T_text, 1).
            pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text).
            energy (Tensor): Batch of padded token-averaged energy
                (B, T_text, 1).
            energy_lengths (LongTensor): Batch of energy lengths (B, T_text).
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor
                (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).

        Returns:
            Tensor: Waveform tensor (B, 1, segment_size * upsample_factor).
            Tensor: Binarization loss ().
            Tensor: Log probability attention matrix (B, T_feats, T_text).
            Tensor: Segments start index tensor (B,).
            Tensor: predicted duration (B, T_text).
            Tensor: ground-truth duration obtained from an alignment module
                (B, T_text).
            Tensor: predicted pitch (B, T_text, 1).
            Tensor: ground-truth averaged pitch (B, T_text, 1).
            Tensor: predicted energy (B, T_text, 1).
            Tensor: ground-truth averaged energy (B, T_text, 1).

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        feats = feats[:, :feats_lengths.max()]  # for data-parallel
        pitch = pitch[:, :pitch_lengths.max()]  # for data-parallel
        energy = energy[:, :energy_lengths.max()]  # for data-parallel

        # forward encoder
        x_masks = self._source_mask(text_lengths)
        hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(feats)
            hs = hs + style_embs.unsqueeze(1)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward alignment module and obtain duration, averaged pitch, energy
        h_masks = make_pad_mask(text_lengths).to(hs.device)
        log_p_attn = self.alignment_module(hs, feats, h_masks)
        ds, bin_loss = viterbi_decode(log_p_attn, text_lengths, feats_lengths)
        ps = average_by_duration(ds, pitch.squeeze(-1), text_lengths,
                                 feats_lengths).unsqueeze(-1)
        es = average_by_duration(ds, energy.squeeze(-1), text_lengths,
                                 feats_lengths).unsqueeze(-1)

        # forward duration predictor and variance predictors
        # optionally detach hs so predictor losses don't update the encoder
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), h_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), h_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
        d_outs = self.duration_predictor(hs, h_masks)

        # use groundtruth in training
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        # upsampling
        h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
        d_masks = make_non_pad_mask(text_lengths).to(ds.device)
        hs = self.length_regulator(hs, ds, h_masks,
                                   d_masks)  # (B, T_feats, adim)

        # forward decoder
        h_masks = self._source_mask(feats_lengths)
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

        # get random segments (windowed training for the discriminator)
        z_segments, z_start_idxs = get_random_segments(
            zs.transpose(1, 2),
            feats_lengths,
            self.segment_size,
        )
        # forward generator
        wav = self.generator(z_segments)

        return (
            wav,
            bin_loss,
            log_p_attn,
            z_start_idxs,
            d_outs,
            ds,
            p_outs,
            ps,
            e_outs,
            es,
        )

    def inference(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        feats_lengths: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Run inference.

        Args:
            text (Tensor): Input text index tensor (B, T_text,).
            text_lengths (Tensor): Text length tensor (B,).
            feats (Tensor): Feature tensor (B, T_feats, aux_channels).
                Required when ``use_gst`` or ``use_teacher_forcing`` is True.
            feats_lengths (Tensor): Feature length tensor (B,).
            pitch (Tensor): Pitch tensor (B, T_feats, 1). Only used with
                teacher forcing.
            energy (Tensor): Energy tensor (B, T_feats, 1). Only used with
                teacher forcing.
            sids (Optional[Tensor]): Speaker index tensor (B,) or (B, 1).
            spembs (Optional[Tensor]): Speaker embedding tensor
                (B, spk_embed_dim).
            lids (Optional[Tensor]): Language index tensor (B,) or (B, 1).
            use_teacher_forcing (bool): Whether to use teacher forcing.

        Returns:
            Tensor: Generated waveform tensor (B, T_wav).
            Tensor: Duration tensor (B, T_text).

        """
        # forward encoder
        x_masks = self._source_mask(text_lengths)
        hs, _ = self.encoder(text, x_masks)  # (B, T_text, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(feats)
            hs = hs + style_embs.unsqueeze(1)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        h_masks = make_pad_mask(text_lengths).to(hs.device)
        if use_teacher_forcing:
            # forward alignment module and obtain duration,
            # averaged pitch, energy
            log_p_attn = self.alignment_module(hs, feats, h_masks)
            d_outs, _ = viterbi_decode(log_p_attn, text_lengths,
                                       feats_lengths)
            p_outs = average_by_duration(d_outs, pitch.squeeze(-1),
                                         text_lengths,
                                         feats_lengths).unsqueeze(-1)
            e_outs = average_by_duration(d_outs, energy.squeeze(-1),
                                         text_lengths,
                                         feats_lengths).unsqueeze(-1)
        else:
            # forward duration predictor and variance predictors
            p_outs = self.pitch_predictor(hs, h_masks.unsqueeze(-1))
            e_outs = self.energy_predictor(hs, h_masks.unsqueeze(-1))
            d_outs = self.duration_predictor.inference(hs, h_masks)
        p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        # upsampling
        if feats_lengths is not None:
            h_masks = make_non_pad_mask(feats_lengths).to(hs.device)
        else:
            h_masks = None
        d_masks = make_non_pad_mask(text_lengths).to(d_outs.device)
        hs = self.length_regulator(hs, d_outs, h_masks,
                                   d_masks)  # (B, T_feats, adim)

        # forward decoder
        if feats_lengths is not None:
            h_masks = self._source_mask(feats_lengths)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)

        # forward generator
        wav = self.generator(zs.transpose(1, 2))

        return wav.squeeze(1), d_outs

    def _integrate_with_spk_embed(self, hs: torch.Tensor,
                                  spembs: torch.Tensor) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, T_text, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences
                (B, T_text, adim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type: str, init_enc_alpha: float,
                          init_dec_alpha: float):
        """Reset module parameters and scaled-pos-encoding alphas.

        Args:
            init_type (str): Initialization scheme name; "pytorch" keeps the
                framework defaults untouched.
            init_enc_alpha (float): Initial alpha for the encoder's scaled
                positional encoding (transformer only).
            init_dec_alpha (float): Initial alpha for the decoder's scaled
                positional encoding (transformer only).
        """
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        # NOTE(review): only done for transformer-type modules; conformer
        # encoders use relative pos encoding without this alpha.
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
class FastSpeech2(AbsTTS):
    """FastSpeech2 module.

    This is a module of FastSpeech2 described in `FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and
    energy, we use token-averaged value introduced in `FastPitch: Parallel
    Text-to-speech with Pitch Prediction`_.

    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
        https://arxiv.org/abs/2006.04558
    .. _`FastPitch: Parallel Text-to-speech with Pitch Prediction`:
        https://arxiv.org/abs/2006.06873

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        postnet_dropout_rate: float = 0.5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        reduction_factor: int = 1,
        encoder_type: str = "transformer",
        decoder_type: str = "transformer",
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        # only for conformer
        conformer_rel_pos_type: str = "legacy",
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        zero_triu: bool = False,
        conformer_enc_kernel_size: int = 7,
        conformer_dec_kernel_size: int = 31,
        # duration predictor
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        duration_predictor_dropout_rate: float = 0.1,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
        # extra embedding related
        spks: Optional[int] = None,
        langs: Optional[int] = None,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            adim (int): Attention dimension of encoder/decoder.
            aheads (int): Number of attention heads.
            elayers (int): Number of encoder layers.
            eunits (int): Number of encoder hidden units.
            dlayers (int): Number of decoder layers.
            dunits (int): Number of decoder hidden units.
            postnet_layers (int): Number of postnet layers.
            postnet_chans (int): Number of postnet channels.
            postnet_filts (int): Kernel size of postnet.
            postnet_dropout_rate (float): Dropout rate in postnet.
            use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding.
            use_batch_norm (bool): Whether to use batch normalization in encoder
                prenet.
            encoder_normalize_before (bool): Whether to apply layernorm layer before
                encoder block.
            decoder_normalize_before (bool): Whether to apply layernorm layer before
                decoder block.
            encoder_concat_after (bool): Whether to concatenate attention layer's
                input and output in encoder.
            decoder_concat_after (bool): Whether to concatenate attention layer's
                input and output in decoder.
            reduction_factor (int): Reduction factor.
            encoder_type (str): Encoder type ("transformer" or "conformer").
            decoder_type (str): Decoder type ("transformer" or "conformer").
            transformer_enc_dropout_rate (float): Dropout rate in encoder except
                attention and positional encoding.
            transformer_enc_positional_dropout_rate (float): Dropout rate after
                encoder positional encoding.
            transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
                self-attention module.
            transformer_dec_dropout_rate (float): Dropout rate in decoder except
                attention & positional encoding.
            transformer_dec_positional_dropout_rate (float): Dropout rate after
                decoder positional encoding.
            transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
                self-attention module.
            conformer_rel_pos_type (str): Relative pos encoding type in conformer.
            conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer.
            conformer_self_attn_layer_type (str): Self-attention layer type in
                conformer.
            conformer_activation_type (str): Activation function type in conformer.
            use_macaron_style_in_conformer: Whether to use macaron style FFN.
            use_cnn_in_conformer: Whether to use CNN in conformer.
            zero_triu: Whether to use zero triu in relative self-attention module.
            conformer_enc_kernel_size: Kernel size of encoder conformer.
            conformer_dec_kernel_size: Kernel size of decoder conformer.
            duration_predictor_layers (int): Number of duration predictor layers.
            duration_predictor_chans (int): Number of duration predictor channels.
            duration_predictor_kernel_size (int): Kernel size of duration predictor.
            duration_predictor_dropout_rate (float): Dropout rate in duration
                predictor.
            pitch_predictor_layers (int): Number of pitch predictor layers.
            pitch_predictor_chans (int): Number of pitch predictor channels.
            pitch_predictor_kernel_size (int): Kernel size of pitch predictor.
            pitch_predictor_dropout (float): Dropout rate in pitch predictor.
            pitch_embed_kernel_size (int): Kernel size of pitch embedding.
            pitch_embed_dropout (float): Dropout rate for pitch embedding.
            stop_gradient_from_pitch_predictor: Whether to stop gradient from pitch
                predictor to encoder.
            energy_predictor_layers (int): Number of energy predictor layers.
            energy_predictor_chans (int): Number of energy predictor channels.
            energy_predictor_kernel_size (int): Kernel size of energy predictor.
            energy_predictor_dropout (float): Dropout rate in energy predictor.
            energy_embed_kernel_size (int): Kernel size of energy embedding.
            energy_embed_dropout (float): Dropout rate for energy embedding.
            stop_gradient_from_energy_predictor: Whether to stop gradient from
                energy predictor to encoder.
            spks (Optional[int]): Number of speakers. If set to > 1, assume that the
                sids will be provided as the input and use sid embedding layer.
            langs (Optional[int]): Number of languages. If set to > 1, assume that
                the lids will be provided as the input and use lid embedding layer.
            spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to
                > 0, assume that spembs will be provided as the input.
            spk_embed_integration_type: How to integrate speaker embedding.
            use_gst (bool): Whether to use global style token.
            gst_tokens (int): The number of GST embeddings.
            gst_heads (int): The number of heads in GST multihead attention.
            gst_conv_layers (int): The number of conv layers in GST.
            gst_conv_chans_list: (Sequence[int]): List of the number of channels of
                conv layers in GST.
            gst_conv_kernel_size (int): Kernel size of conv layers in GST.
            gst_conv_stride (int): Stride size of conv layers in GST.
            gst_gru_layers (int): The number of GRU layers in GST.
            gst_gru_units (int): The number of GRU units in GST.
            init_type (str): How to initialize transformer parameters.
            init_enc_alpha (float): Initial value of alpha in scaled pos encoding of
                the encoder.
            init_dec_alpha (float): Initial value of alpha in scaled pos encoding of
                the decoder.
            use_masking (bool): Whether to apply masking for padded part in loss
                calculation.
            use_weighted_masking (bool): Whether to apply weighted masking in loss
                calculation.

        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # the last symbol id is reserved for <eos>
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # check relative positional encoding compatibility
        if "conformer" in [encoder_type, decoder_type]:
            if conformer_rel_pos_type == "legacy":
                if conformer_pos_enc_layer_type == "rel_pos":
                    conformer_pos_enc_layer_type = "legacy_rel_pos"
                    logging.warning(
                        "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_pos_enc_layer_type = 'latest'.")
                if conformer_self_attn_layer_type == "rel_selfattn":
                    conformer_self_attn_layer_type = "legacy_rel_selfattn"
                    logging.warning(
                        "Fallback to "
                        "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_pos_enc_layer_type = 'latest'.")
            elif conformer_rel_pos_type == "latest":
                assert conformer_pos_enc_layer_type != "legacy_rel_pos"
                assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
            else:
                raise ValueError(
                    f"Unknown rel_pos_type: {conformer_rel_pos_type}")

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        if encoder_type == "transformer":
            self.encoder = TransformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif encoder_type == "conformer":
            self.encoder = ConformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_enc_kernel_size,
                zero_triu=zero_triu,
            )
        else:
            raise ValueError(f"{encoder_type} is not supported.")

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define spk and lang embedding
        self.spks = None
        if spks is not None and spks > 1:
            self.spks = spks
            self.sid_emb = torch.nn.Embedding(spks, adim)
        self.langs = None
        if langs is not None and langs > 1:
            self.langs = langs
            self.lid_emb = torch.nn.Embedding(langs, adim)

        # define additional projection for speaker embedding
        self.spk_embed_dim = None
        if spk_embed_dim is not None and spk_embed_dim > 0:
            self.spk_embed_dim = spk_embed_dim
            self.spk_embed_integration_type = spk_embed_integration_type
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )

        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        if decoder_type == "transformer":
            self.decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif decoder_type == "conformer":
            self.decoder = ConformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size,
            )
        else:
            raise ValueError(f"{decoder_type} is not supported.")

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        ))

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeech2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        feats: torch.Tensor,
        feats_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        pitch: torch.Tensor,
        pitch_lengths: torch.Tensor,
        energy: torch.Tensor,
        energy_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        joint_training: bool = False,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded token ids (B, T_text).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            feats (Tensor): Batch of padded target features (B, T_feats, odim).
            feats_lengths (LongTensor): Batch of the lengths of each target (B,).
            durations (LongTensor): Batch of padded durations (B, T_text + 1).
            durations_lengths (LongTensor): Batch of duration lengths
                (B, T_text + 1).
            pitch (Tensor): Batch of padded token-averaged pitch
                (B, T_text + 1, 1).
            pitch_lengths (LongTensor): Batch of pitch lengths (B, T_text + 1).
            energy (Tensor): Batch of padded token-averaged energy
                (B, T_text + 1, 1).
            energy_lengths (LongTensor): Batch of energy lengths (B, T_text + 1).
            spembs (Optional[Tensor]): Batch of speaker embeddings
                (B, spk_embed_dim).
            sids (Optional[Tensor]): Batch of speaker IDs (B, 1).
            lids (Optional[Tensor]): Batch of language IDs (B, 1).
            joint_training (bool): Whether to perform joint training with vocoder.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value if not joint training else model outputs.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        feats = feats[:, :feats_lengths.max()]  # for data-parallel
        durations = durations[:, :durations_lengths.max()]  # for data-parallel
        pitch = pitch[:, :pitch_lengths.max()]  # for data-parallel
        energy = energy[:, :energy_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys, ds, ps, es = feats, durations, pitch, energy
        olens = feats_lengths

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(
            xs,
            ilens,
            ys,
            olens,
            ds,
            ps,
            es,
            spembs=spembs,
            sids=sids,
            lids=lids,
            is_inference=False,
        )

        # modify mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # postnet is optional; skip its loss term when absent
        if self.postnet is None:
            after_outs = None

        # calculate loss
        l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=ds,
            ps=ps,
            es=es,
            ilens=ilens,
            olens=olens,
        )
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        stats = dict(
            l1_loss=l1_loss.item(),
            duration_loss=duration_loss.item(),
            pitch_loss=pitch_loss.item(),
            energy_loss=energy_loss.item(),
        )

        # report extra information
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(), )
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(), )

        if not joint_training:
            stats.update(loss=loss.item())
            loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                                   loss.device)
            return loss, stats, weight
        else:
            return loss, stats, after_outs if after_outs is not None else before_outs

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: Optional[torch.Tensor] = None,
        olens: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
        ps: Optional[torch.Tensor] = None,
        es: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        is_inference: bool = False,
        alpha: float = 1.0,
    ) -> Sequence[torch.Tensor]:
        """Shared forward core for training and inference paths."""
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, T_text, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)

        # integrate with SID and LID embeddings
        if self.spks is not None:
            sid_embs = self.sid_emb(sids.view(-1))
            hs = hs + sid_embs.unsqueeze(1)
        if self.langs is not None:
            lid_embs = self.lid_emb(lids.view(-1))
            hs = hs + lid_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(ilens).to(xs.device)

        # optionally detach hs so variance-predictor gradients do not flow
        # back into the encoder
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, T_text)
            # use prediction in inference
            p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, d_outs,
                                       alpha)  # (B, T_feats, adim)
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, ds)  # (B, T_feats, adim)

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, T_feats, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, T_feats, odim)

        # postnet -> (B, T_feats//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs, p_outs, e_outs

    def inference(
        self,
        text: torch.Tensor,
        feats: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        sids: Optional[torch.Tensor] = None,
        lids: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
        use_teacher_forcing: bool = False,
    ) -> Dict[str, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T_text,).
            feats (Optional[Tensor]): Feature sequence to extract style (N, idim).
            durations (Optional[Tensor]): Groundtruth of duration (T_text + 1,).
            spembs (Optional[Tensor]): Speaker embedding vector (spk_embed_dim,).
            sids (Optional[Tensor]): Speaker ID (1,).
            lids (Optional[Tensor]): Language ID (1,).
            pitch (Optional[Tensor]): Groundtruth of token-avg pitch
                (T_text + 1, 1).
            energy (Optional[Tensor]): Groundtruth of token-avg energy
                (T_text + 1, 1).
            alpha (float): Alpha to control the speed.
            use_teacher_forcing (bool): Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.

        Returns:
            Dict[str, Tensor]: Output dict including the following items:
                * feat_gen (Tensor): Output sequence of features (T_feats, odim).
                * duration (Tensor): Duration sequence (T_text + 1,).
                * pitch (Tensor): Pitch sequence (T_text + 1,).
                * energy (Tensor): Energy sequence (T_text + 1,).

        """
        x, y = text, feats
        spemb, d, p, e = spembs, durations, pitch, energy

        # add eos at the last of sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        if y is not None:
            ys = y.unsqueeze(0)
        if spemb is not None:
            spembs = spemb.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
            _, outs, d_outs, p_outs, e_outs = self._forward(
                xs,
                ilens,
                ys,
                ds=ds,
                ps=ps,
                es=es,
                spembs=spembs,
                sids=sids,
                lids=lids,
            )  # (1, T_feats, odim)
        else:
            _, outs, d_outs, p_outs, e_outs = self._forward(
                xs,
                ilens,
                ys,
                spembs=spembs,
                sids=sids,
                lids=lids,
                is_inference=True,
                alpha=alpha,
            )  # (1, T_feats, odim)

        return dict(
            feat_gen=outs[0],
            duration=d_outs[0],
            pitch=p_outs[0],
            energy=e_outs[0],
        )

    def _integrate_with_spk_embed(self, hs: torch.Tensor,
                                  spembs: torch.Tensor) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, T_text, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, T_text, adim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type: str, init_enc_alpha: float,
                          init_dec_alpha: float):
        """Reset module parameters and scaled-pos-enc alpha values."""
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        # NOTE(review): defaults differ from the espnet2-style FastSpeech2
        # above (normalize_before False here) -- presumably intentional for
        # this legacy variant; confirm before unifying.
        encoder_normalize_before: bool = False,
        decoder_normalize_before: bool = False,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        reduction_factor: int = 1,
        encoder_type: str = "transformer",
        decoder_type: str = "transformer",
        # only for conformer
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        conformer_enc_kernel_size: int = 7,
        conformer_dec_kernel_size: int = 31,
        # duration predictor
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
        # pretrained spk emb
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        # GST
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        duration_predictor_dropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech2 module.

        Builds encoder/decoder (transformer or conformer), the duration,
        pitch and energy predictors with their embedding convolutions, the
        length regulator, optional GST/speaker-embedding integration, the
        final projection + postnet, and the FastSpeech2 loss.
        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # the last symbol id is reserved for <eos>
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst
        self.spk_embed_dim = spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        if encoder_type == "transformer":
            self.encoder = TransformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif encoder_type == "conformer":
            self.encoder = ConformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_enc_kernel_size,
            )
        else:
            raise ValueError(f"{encoder_type} is not supported.")

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )

        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        if decoder_type == "transformer":
            self.decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif decoder_type == "conformer":
            self.decoder = ConformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size,
            )
        else:
            raise ValueError(f"{decoder_type} is not supported.")

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        ))

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeech2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking)
class Tacotron2_sa(TTSInterface, torch.nn.Module):
    """Tacotron2-based TTS model driven by an explicit duration predictor.

    Alignment between phonemes and acoustic frames is supplied by a
    FastSpeech-style duration predictor instead of attention.  The decoder can
    additionally be conditioned on pitch/energy embeddings
    (``use_fe_condition``) and per-phoneme relative position features
    (``append_position``).
    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Number of dimension of embedding",
        )
        group.add_argument("--elayers",
                           default=1,
                           type=int,
                           help="Number of encoder layers")
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # decoder
        group.add_argument("--dlayers",
                           default=2,
                           type=int,
                           help="Number of decoder layers")
        group.add_argument("--dunits",
                           default=1024,
                           type=int,
                           help="Number of decoder hidden units")
        group.add_argument("--prenet-layers",
                           default=2,
                           type=int,
                           help="Number of prenet layers")
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument("--postnet-layers",
                           default=5,
                           type=int,
                           help="Number of postnet layers")
        group.add_argument("--postnet-chans",
                           default=512,
                           type=int,
                           help="Number of postnet channels")
        group.add_argument("--postnet-filts",
                           default=5,
                           type=int,
                           help="Filter size of postnet")
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help="Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument("--dropout-rate",
                           default=0.5,
                           type=float,
                           help="Dropout rate")
        group.add_argument("--zoneout-rate",
                           default=0.1,
                           type=float,
                           help="Zoneout rate")
        group.add_argument("--reduction-factor",
                           default=1,
                           type=int,
                           help="Reduction factor")
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument("--spc-dim",
                           default=None,
                           type=int,
                           help="Number of spectrogram dimensions")
        group.add_argument("--pretrained-model",
                           default=None,
                           type=str,
                           help="Pretrained model path")
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        # duration predictor settings
        group.add_argument(
            "--duration-predictor-layers",
            default=2,
            type=int,
            help="Number of layers in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-chans",
            default=384,
            type=int,
            help="Number of channels in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-kernel-size",
            default=3,
            type=int,
            help="Kernel size in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for duration predictor",
        )
        return parser

    @staticmethod
    def _trainable_mparams(module):
        """Return the number of trainable parameters of ``module`` in millions."""
        params = filter(lambda p: p.requires_grad, module.parameters())
        return sum(np.prod(p.size()) for p in params) / 1_000_000

    def __init__(self, idim, odim, args=None, com_args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional): Model options; any option missing from
                it is filled with the defaults declared in ``add_arguments``:
                - embed_dim / elayers / eunits / econv_* : encoder settings.
                - dlayers / dunits / prenet_* / postnet_* : decoder settings.
                - output_activation (str): Name of activation for outputs.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (bool): Whether to concatenate encoder embedding
                  with decoder lstm outputs.
                - dropout_rate / zoneout_rate (float): Regularization rates.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram dimensions.
                - use_masking / use_weighted_masking (bool): Loss masking mode.
                - duration_predictor_* : duration predictor settings.
            com_args (Namespace, optional): Common options shared across
                models; supplies ``use_fe_condition`` and ``append_position``
                when they are absent from ``args``.
        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments with the defaults from add_arguments
        args = fill_missing_args(args, self.add_arguments)

        # backfill options that only exist in the common namespace
        args = vars(args)
        if 'use_fe_condition' not in args.keys():
            args['use_fe_condition'] = com_args.use_fe_condition
        if 'append_position' not in args.keys():
            args['append_position'] = com_args.append_position
        args = argparse.Namespace(**args)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.embed_dim = args.embed_dim
        self.spk_embed_dim = args.spk_embed_dim
        self.reduction_factor = args.reduction_factor
        self.use_fe_condition = args.use_fe_condition
        self.append_position = args.append_position

        # define activation function for the final output
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError("there is no such an activation function. (%s)" %
                             args.output_activation)

        # set padding idx
        padding_idx = 0

        # define network modules
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            # NOTE(review): --encoder-resume is not declared in add_arguments,
            # so fill_missing_args cannot backfill it; fall back to None
            # instead of raising AttributeError when the option is absent.
            resume=getattr(args, "encoder_resume", None),
        )
        # decoder input grows by spk_embed_dim when speaker embeddings are used
        dec_idim = (args.eunits if args.spk_embed_dim is None else
                    args.eunits + args.spk_embed_dim)
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
            use_fe_condition=args.use_fe_condition,
            append_position=args.append_position,
        )
        self.duration_predictor = DurationPredictor(
            idim=dec_idim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )
        reduction = 'none' if args.use_weighted_masking else 'mean'
        self.duration_criterion = DurationPredictorLoss(reduction=reduction)

        # -------------- pitch/energy predictor definition ---------------#
        if self.use_fe_condition:
            output_dim = 1
            # pitch prediction
            pitch_predictor_layers = 2
            pitch_predictor_chans = 384
            pitch_predictor_kernel_size = 3
            pitch_predictor_dropout_rate = 0.5
            pitch_embed_kernel_size = 9
            pitch_embed_dropout_rate = 0.5
            self.stop_gradient_from_pitch_predictor = False
            self.pitch_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=pitch_predictor_layers,
                n_chans=pitch_predictor_chans,
                kernel_size=pitch_predictor_kernel_size,
                dropout_rate=pitch_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.pitch_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=pitch_embed_kernel_size,
                    padding=(pitch_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(pitch_embed_dropout_rate),
            )
            # energy prediction
            energy_predictor_layers = 2
            energy_predictor_chans = 384
            energy_predictor_kernel_size = 3
            energy_predictor_dropout_rate = 0.5
            energy_embed_kernel_size = 9
            energy_embed_dropout_rate = 0.5
            self.stop_gradient_from_energy_predictor = False
            self.energy_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=energy_predictor_layers,
                n_chans=energy_predictor_chans,
                kernel_size=energy_predictor_kernel_size,
                dropout_rate=energy_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.energy_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=energy_embed_kernel_size,
                    padding=(energy_embed_kernel_size - 1) // 2,
                ),
                torch.nn.Dropout(energy_embed_dropout_rate),
            )

        # define criterions
        self.prosody_criterion = prosody_criterions(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking)
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

        # report trainable-parameter counts per submodule
        print('\n############## number of network parameters ##############\n')
        named_modules = [
            ("Encoder", self.enc),
            ("Decoder", self.dec),
            ("duration_predictor", self.duration_predictor),
        ]
        # FIX: the pitch/energy modules only exist when use_fe_condition is
        # True; reporting them unconditionally raised AttributeError otherwise.
        if self.use_fe_condition:
            named_modules += [
                ("pitch_predictor", self.pitch_predictor),
                ("energy_predictor", self.energy_predictor),
                ("pitch_embed", self.pitch_embed),
                ("energy_embed", self.energy_embed),
            ]
        for tag, module in named_modules:
            print('Trainable Parameters for %s: %.5fM' %
                  (tag, self._trainable_mparams(module)))
        print('Trainable Parameters for whole network: %.5fM' %
              self._trainable_mparams(self))
        print('\n##########################################################\n')

    def forward(self,
                xs,
                ilens,
                ys,
                olens,
                spembs=None,
                extras=None,
                new_ys=None,
                non_zero_lens_mask=None,
                ds_nonzeros=None,
                output_masks=None,
                position=None,
                f0=None,
                energy=None,
                *args,
                **kwargs):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional): Batch of speaker embedding vectors
                (B, spk_embed_dim).
            extras (Tensor, optional): Batch of groundtruth durations
                (B, Tmax, 1).
            new_ys (Tensor): Reorganized mel-spectrograms.
            non_zero_lens_mask (Tensor): Mask of non-zero-duration phonemes.
            ds_nonzeros (Tensor): Non-zero durations.
            output_masks (Tensor): Decoder output masks.
            position (Tensor): Position values for each phoneme.
            f0 (Tensor): Pitch values.
            energy (Tensor): Energy values.

        Returns:
            Tensor: Loss value.
        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]

        # calculate FCL-taco2-enc outputs
        hs, hlens = self.enc(xs, ilens)
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        # duration predictor loss calculation
        ds = extras.squeeze(-1)
        d_masks = make_pad_mask(ilens).to(xs.device)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        duration_masks = make_non_pad_mask(ilens).to(ys.device)
        d_outs = d_outs.masked_select(duration_masks)
        duration_loss = self.duration_criterion(
            d_outs, ds.masked_select(duration_masks))

        if self.use_fe_condition:
            expand_hs = hs
            fe_masks = d_masks
            # optionally cut the gradient flow from the variance predictors
            # back into the encoder
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(expand_hs.detach(),
                                              fe_masks.unsqueeze(-1))
            else:
                p_outs = self.pitch_predictor(
                    expand_hs, fe_masks.unsqueeze(-1))  # B x Tmax x 1
            if self.stop_gradient_from_energy_predictor:
                e_outs = self.energy_predictor(expand_hs.detach(),
                                               fe_masks.unsqueeze(-1))
            else:
                e_outs = self.energy_predictor(
                    expand_hs, fe_masks.unsqueeze(-1))  # B x Tmax x 1
            pitch_loss = self.prosody_criterion(p_outs, f0, ilens)
            energy_loss = self.prosody_criterion(e_outs, energy, ilens)
            # embed the groundtruth pitch/energy for decoder conditioning
            p_embs = self.pitch_embed(f0.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(energy.transpose(1, 2)).transpose(1, 2)
        else:
            p_embs = None
            e_embs = None

        ylens = olens
        after_outs, before_outs = self.dec(hs, hlens, ds, ys, ylens, new_ys,
                                           non_zero_lens_mask, ds_nonzeros,
                                           output_masks, position, f0, energy,
                                           p_embs, e_embs)

        # modify mod part of groundtruth to match the reduction factor
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_out = max(olens)
            ys = ys[:, :max_out]

        # calculate taco2 loss
        l1_loss, mse_loss = self.taco2_loss(after_outs, before_outs, ys, olens)
        loss = l1_loss + mse_loss + duration_loss
        report_keys = [
            {"l1_loss": l1_loss.item()},
            {"mse_loss": mse_loss.item()},
            {"dur_loss": duration_loss.item()},
        ]
        if self.use_fe_condition:
            prosody_weight = 1.0
            loss = loss + prosody_weight * (pitch_loss + energy_loss)
            report_keys += [
                {'pitch_loss': pitch_loss.item()},
                {'energy_loss': energy_loss.item()},
            ]
        report_keys += [{"loss": loss.item()}]
        self.reporter.report(report_keys)
        return loss

    def inference(self,
                  x,
                  inference_args,
                  spemb=None,
                  dur=None,
                  f0=None,
                  energy=None,
                  utt_id=None,
                  *args,
                  **kwargs):
        """Generate the sequence of features given the sequences of characters.

        Args:
            x (Tensor): Input sequence of characters (T,).
            inference_args (Namespace): Dummy for compatibility.
            spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim).
            dur (Tensor, optional): Groundtruth durations; predicted when None.
            f0 (Tensor, optional): Groundtruth pitch; predicted when None.
            energy (Tensor, optional): Groundtruth energy; predicted when None.

        Returns:
            Tensor: Output sequence of features (L, odim).
        """
        # inference
        h = self.enc.inference(x)  # Tmax x h-dim
        if self.spk_embed_dim is not None:
            spemb = F.normalize(spemb, dim=0).unsqueeze(0).expand(h.size(0), -1)
            h = torch.cat([h, spemb], dim=-1)
        ilens = torch.LongTensor([h.shape[0]]).to(h.device)
        d_masks = make_pad_mask(ilens).to(h.device)

        # use groundtruth durations when provided, otherwise predict them
        if dur is not None:
            d_outs = dur.reshape(-1).long()
        else:
            d_outs = self.duration_predictor.inference(h.unsqueeze(0),
                                                       d_masks)  # B x Tmax
            d_outs = d_outs.squeeze(0).long()

        if self.use_fe_condition:
            if f0 is not None:
                p_outs = f0.unsqueeze(0)
                e_outs = energy.unsqueeze(0)
            else:
                expand_hs = h.unsqueeze(0)
                fe_masks = d_masks
                p_outs = self.pitch_predictor(expand_hs,
                                              fe_masks.unsqueeze(-1))
                e_outs = self.energy_predictor(expand_hs,
                                               fe_masks.unsqueeze(-1))
            p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(
                1, 2).squeeze(0)
            e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(
                1, 2).squeeze(0)
        else:
            p_outs = None
            e_outs = None
            p_embs = None
            e_embs = None

        if self.append_position:
            # relative position of each frame within its phoneme, in [0, 1);
            # zero-duration phonemes contribute no frames and are skipped
            position = []
            for iid in range(d_outs.shape[0]):
                if d_outs[iid] != 0:
                    position.append(
                        torch.FloatTensor(list(range(d_outs[iid].long()))) /
                        d_outs[iid])
            position = pad_list(position, 0)
            position = position.to(h.device)
        else:
            position = None

        outs = self.dec.inference(
            h,
            d_outs,
            position,
            p_embs,
            e_embs,
        )
        return outs

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss` and
        `validation/main/loss` values.
        also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.
        """
        plot_keys = ["loss", "l1_loss", "mse_loss", "dur_loss"]
        if self.use_fe_condition:
            plot_keys += ["pitch_loss", "energy_loss"]
        return plot_keys
class FeedForwardTransformer(TTSInterface, torch.nn.Module): """Feed Forward Transformer for TTS a.k.a. FastSpeech. This is a module of FastSpeech, feed-forward Transformer with duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive processing during inference, resulting in fast decoding compared with auto-regressive Transformer. .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: https://arxiv.org/pdf/1905.09263.pdf """ @staticmethod def add_arguments(parser): """Add model-specific arguments to the parser.""" group = parser.add_argument_group( "feed-forward transformer model setting") # network structure related group.add_argument( "--adim", default=384, type=int, help="Number of attention transformation dimensions") group.add_argument("--aheads", default=4, type=int, help="Number of heads for multi head attention") group.add_argument("--elayers", default=6, type=int, help="Number of encoder layers") group.add_argument("--eunits", default=1536, type=int, help="Number of encoder hidden units") group.add_argument("--dlayers", default=6, type=int, help="Number of decoder layers") group.add_argument("--dunits", default=1536, type=int, help="Number of decoder hidden units") group.add_argument("--positionwise-layer-type", default="linear", type=str, choices=["linear", "conv1d", "conv1d-linear"], help="Positionwise layer type.") group.add_argument("--positionwise-conv-kernel-size", default=3, type=int, help="Kernel size of positionwise conv1d layer") group.add_argument("--postnet-layers", default=0, type=int, help="Number of postnet layers") group.add_argument("--postnet-chans", default=256, type=int, help="Number of postnet channels") group.add_argument("--postnet-filts", default=5, type=int, help="Filter size of postnet") group.add_argument("--use-batch-norm", default=True, type=strtobool, help="Whether to use batch normalization") group.add_argument( "--use-scaled-pos-enc", 
default=True, type=strtobool, help= "Use trainable scaled positional encoding instead of the fixed scale one" ) group.add_argument( "--encoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before encoder block") group.add_argument( "--decoder-normalize-before", default=False, type=strtobool, help="Whether to apply layer norm before decoder block") group.add_argument( "--encoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in encoder" ) group.add_argument( "--decoder-concat-after", default=False, type=strtobool, help= "Whether to concatenate attention layer's input and output in decoder" ) group.add_argument("--duration-predictor-layers", default=2, type=int, help="Number of layers in duration predictor") group.add_argument("--duration-predictor-chans", default=384, type=int, help="Number of channels in duration predictor") group.add_argument("--duration-predictor-kernel-size", default=3, type=int, help="Kernel size in duration predictor") group.add_argument("--teacher-model", default=None, type=str, nargs="?", help="Teacher model file path") group.add_argument("--reduction-factor", default=1, type=int, help="Reduction factor") group.add_argument("--spk-embed-dim", default=None, type=int, help="Number of speaker embedding dimensions") group.add_argument("--spk-embed-integration-type", type=str, default="add", choices=["add", "concat"], help="How to integrate speaker embedding") # training related group.add_argument("--transformer-init", type=str, default="pytorch", choices=[ "pytorch", "xavier_uniform", "xavier_normal", "kaiming_uniform", "kaiming_normal" ], help="How to initialize transformer parameters") group.add_argument( "--initial-encoder-alpha", type=float, default=1.0, help="Initial alpha value in encoder's ScaledPositionalEncoding") group.add_argument( "--initial-decoder-alpha", type=float, default=1.0, help="Initial alpha value in decoder's 
ScaledPositionalEncoding") group.add_argument("--transformer-lr", default=1.0, type=float, help="Initial value of learning rate") group.add_argument("--transformer-warmup-steps", default=4000, type=int, help="Optimizer warmup steps") group.add_argument( "--transformer-enc-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder except for attention") group.add_argument( "--transformer-enc-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder positional encoding") group.add_argument( "--transformer-enc-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder self-attention") group.add_argument( "--transformer-dec-dropout-rate", default=0.1, type=float, help= "Dropout rate for transformer decoder except for attention and pos encoding" ) group.add_argument( "--transformer-dec-positional-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder positional encoding") group.add_argument( "--transformer-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer decoder self-attention") group.add_argument( "--transformer-enc-dec-attn-dropout-rate", default=0.1, type=float, help="Dropout rate for transformer encoder-decoder attention") group.add_argument("--duration-predictor-dropout-rate", default=0.1, type=float, help="Dropout rate for duration predictor") group.add_argument("--postnet-dropout-rate", default=0.5, type=float, help="Dropout rate in postnet") group.add_argument("--transfer-encoder-from-teacher", default=True, type=strtobool, help="Whether to transfer teacher's parameters") group.add_argument( "--transferred-encoder-module", default="all", type=str, choices=["all", "embed"], help="Encoder modeules to be trasferred from teacher") # loss related group.add_argument( "--use-masking", default=True, type=strtobool, help="Whether to use masking in calculation of loss") group.add_argument( "--use-weighted-masking", 
default=False, type=strtobool, help="Whether to use weighted masking in calculation of loss") return parser def __init__(self, idim, odim, args=None): """Initialize feed-forward Transformer module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. args (Namespace, optional): - elayers (int): Number of encoder layers. - eunits (int): Number of encoder hidden units. - adim (int): Number of attention transformation dimensions. - aheads (int): Number of heads for multi head attention. - dlayers (int): Number of decoder layers. - dunits (int): Number of decoder hidden units. - use_scaled_pos_enc (bool): Whether to use trainable scaled positional encoding. - encoder_normalize_before (bool): Whether to perform layer normalization before encoder block. - decoder_normalize_before (bool): Whether to perform layer normalization before decoder block. - encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder. - decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder. - duration_predictor_layers (int): Number of duration predictor layers. - duration_predictor_chans (int): Number of duration predictor channels. - duration_predictor_kernel_size (int): Kernel size of duration predictor. - spk_embed_dim (int): Number of speaker embedding dimenstions. - spk_embed_integration_type: How to integrate speaker embedding. - teacher_model (str): Teacher auto-regressive transformer model path. - reduction_factor (int): Reduction factor. - transformer_init (float): How to initialize transformer parameters. - transformer_lr (float): Initial value of learning rate. - transformer_warmup_steps (int): Optimizer warmup steps. - transformer_enc_dropout_rate (float): Dropout rate in encoder except attention & positional encoding. - transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding. 
- transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module. - transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding. - transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding. - transformer_dec_attn_dropout_rate (float): Dropout rate in deocoder self-attention module. - transformer_enc_dec_attn_dropout_rate (float): Dropout rate in encoder-deocoder attention module. - use_masking (bool): Whether to apply masking for padded part in loss calculation. - use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. - transfer_encoder_from_teacher: Whether to transfer encoder using teacher encoder parameters. - transferred_encoder_module: Encoder module to be initialized using teacher parameters. """ # initialize base classes TTSInterface.__init__(self) torch.nn.Module.__init__(self) # fill missing arguments args = fill_missing_args(args, self.add_arguments) # store hyperparameters self.idim = idim self.odim = odim self.reduction_factor = args.reduction_factor self.use_scaled_pos_enc = args.use_scaled_pos_enc self.spk_embed_dim = args.spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = args.spk_embed_integration_type # use idx 0 as padding idx padding_idx = 0 # get positional encoding class pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding # define encoder encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=args.adim, padding_idx=padding_idx) self.encoder = Encoder( idim=idim, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.eunits, num_blocks=args.elayers, input_layer=encoder_input_layer, dropout_rate=args.transformer_enc_dropout_rate, positional_dropout_rate=args. 
transformer_enc_positional_dropout_rate, attention_dropout_rate=args.transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.encoder_normalize_before, concat_after=args.encoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size) # define additional projection for speaker embedding if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, args.adim) else: self.projection = torch.nn.Linear( args.adim + self.spk_embed_dim, args.adim) # define duration predictor self.duration_predictor = DurationPredictor( idim=args.adim, n_layers=args.duration_predictor_layers, n_chans=args.duration_predictor_chans, kernel_size=args.duration_predictor_kernel_size, dropout_rate=args.duration_predictor_dropout_rate, ) # define length regulator self.length_regulator = LengthRegulator() # define decoder # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder self.decoder = Encoder( idim=0, attention_dim=args.adim, attention_heads=args.aheads, linear_units=args.dunits, num_blocks=args.dlayers, input_layer=None, dropout_rate=args.transformer_dec_dropout_rate, positional_dropout_rate=args. 
transformer_dec_positional_dropout_rate, attention_dropout_rate=args.transformer_dec_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=args.decoder_normalize_before, concat_after=args.decoder_concat_after, positionwise_layer_type=args.positionwise_layer_type, positionwise_conv_kernel_size=args.positionwise_conv_kernel_size) # define final projection self.feat_out = torch.nn.Linear(args.adim, odim * args.reduction_factor) # define postnet self.postnet = None if args.postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=args.postnet_layers, n_chans=args.postnet_chans, n_filts=args.postnet_filts, use_batch_norm=args.use_batch_norm, dropout_rate=args.postnet_dropout_rate) # initialize parameters self._reset_parameters(init_type=args.transformer_init, init_enc_alpha=args.initial_encoder_alpha, init_dec_alpha=args.initial_decoder_alpha) # define teacher model if args.teacher_model is not None: self.teacher = self._load_teacher_model(args.teacher_model) else: self.teacher = None # define duration calculator if self.teacher is not None: self.duration_calculator = DurationCalculator(self.teacher) else: self.duration_calculator = None # transfer teacher parameters if self.teacher is not None and args.transfer_encoder_from_teacher: self._transfer_from_teacher(args.transferred_encoder_module) # define criterions self.criterion = FeedForwardTransformerLoss( use_masking=args.use_masking, use_weighted_masking=args.use_weighted_masking) def _forward(self, xs, ilens, ys=None, olens=None, spembs=None, ds=None, is_inference=False): # forward encoder x_masks = self._source_mask(ilens) hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) # integrate speaker embedding if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spembs) # forward duration predictor and length regulator d_masks = make_pad_mask(ilens).to(xs.device) if is_inference: d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) hs = self.length_regulator(hs, 
d_outs, ilens) # (B, Lmax, adim) else: if ds is None: with torch.no_grad(): ds = self.duration_calculator(xs, ilens, ys, olens, spembs) # (B, Tmax) d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax) hs = self.length_regulator(hs, ds, ilens) # (B, Lmax, adim) # forward decoder if olens is not None: if self.reduction_factor > 1: olens_in = olens.new( [olen // self.reduction_factor for olen in olens]) else: olens_in = olens h_masks = self._source_mask(olens_in) else: h_masks = None zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet(before_outs.transpose( 1, 2)).transpose(1, 2) if is_inference: return before_outs, after_outs, d_outs else: return before_outs, after_outs, ds, d_outs def forward(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs): """Calculate forward propagation. Args: xs (Tensor): Batch of padded character ids (B, Tmax). ilens (LongTensor): Batch of lengths of each input batch (B,). ys (Tensor): Batch of padded target features (B, Lmax, odim). olens (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1). Returns: Tensor: Loss value. 
""" # remove unnecessary padded part (for multi-gpus) xs = xs[:, :max(ilens)] ys = ys[:, :max(olens)] if extras is not None: extras = extras[:, :max(ilens)].squeeze(-1) # forward propagation before_outs, after_outs, ds, d_outs = self._forward(xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False) # modifiy mod part of groundtruth if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] # calculate loss if self.postnet is None: l1_loss, duration_loss = self.criterion(None, before_outs, d_outs, ys, ds, ilens, olens) else: l1_loss, duration_loss = self.criterion(after_outs, before_outs, d_outs, ys, ds, ilens, olens) loss = l1_loss + duration_loss report_keys = [ { "l1_loss": l1_loss.item() }, { "duration_loss": duration_loss.item() }, { "loss": loss.item() }, ] # report extra information if self.use_scaled_pos_enc: report_keys += [ { "encoder_alpha": self.encoder.embed[-1].alpha.data.item() }, { "decoder_alpha": self.decoder.embed[-1].alpha.data.item() }, ] self.reporter.report(report_keys) return loss def calculate_all_attentions(self, xs, ilens, ys, olens, spembs=None, extras=None, *args, **kwargs): """Calculate all of the attention weights. Args: xs (Tensor): Batch of padded character ids (B, Tmax). ilens (LongTensor): Batch of lengths of each input batch (B,). ys (Tensor): Batch of padded target features (B, Lmax, odim). olens (LongTensor): Batch of the lengths of each target (B,). spembs (Tensor, optional): Batch of speaker embedding vectors (B, spk_embed_dim). extras (Tensor, optional): Batch of precalculated durations (B, Tmax, 1). Returns: dict: Dict of attention weights and outputs. 
""" with torch.no_grad(): # remove unnecessary padded part (for multi-gpus) xs = xs[:, :max(ilens)] ys = ys[:, :max(olens)] if extras is not None: extras = extras[:, :max(ilens)].squeeze(-1) # forward propagation outs = self._forward(xs, ilens, ys, olens, spembs=spembs, ds=extras, is_inference=False)[1] att_ws_dict = dict() for name, m in self.named_modules(): if isinstance(m, MultiHeadedAttention): attn = m.attn.cpu().numpy() if "encoder" in name: attn = [a[:, :l, :l] for a, l in zip(attn, ilens.tolist())] elif "decoder" in name: if "src" in name: attn = [ a[:, :ol, :il] for a, il, ol in zip( attn, ilens.tolist(), olens.tolist()) ] elif "self" in name: attn = [ a[:, :l, :l] for a, l in zip(attn, olens.tolist()) ] else: logging.warning("unknown attention module: " + name) else: logging.warning("unknown attention module: " + name) att_ws_dict[name] = attn att_ws_dict["predicted_fbank"] = [ m[:l].T for m, l in zip(outs.cpu().numpy(), olens.tolist()) ] return att_ws_dict def inference(self, x, inference_args, spemb=None, *args, **kwargs): """Generate the sequence of features given the sequences of characters. Args: x (Tensor): Input sequence of characters (T,). inference_args (Namespace): Dummy for compatibility. spemb (Tensor, optional): Speaker embedding vector (spk_embed_dim). Returns: Tensor: Output sequence of features (L, odim). None: Dummy for compatibility. None: Dummy for compatibility. """ # setup batch axis ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) xs = x.unsqueeze(0) if spemb is not None: spembs = spemb.unsqueeze(0) else: spembs = None # inference _, outs, _ = self._forward(xs, ilens, spembs=spembs, is_inference=True) # (1, L, odim) return outs[0], None, None def _integrate_with_spk_embed(self, hs, spembs): """Integrate speaker embedding with hidden states. Args: hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). 
    def _integrate_with_spk_embed(self, hs, spembs):
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim)

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens):
        """Make masks for self-attention.

        Args:
            ilens (LongTensor or List): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _load_teacher_model(self, model_path):
        """Load and freeze a pretrained teacher model from a snapshot path."""
        # get teacher model config
        idim, odim, args = get_model_conf(model_path)

        # assert dimension is the same between teacher and student
        assert idim == self.idim
        assert odim == self.odim
        assert args.reduction_factor == self.reduction_factor

        # load teacher model
        # NOTE: local import avoids a circular dependency at module load time
        from espnet.utils.dynamic_import import dynamic_import
        model_class = dynamic_import(args.model_module)
        model = model_class(idim, odim, args)
        torch_load(model_path, model)

        # freeze teacher model parameters so it is never updated
        for p in model.parameters():
            p.requires_grad = False

        return model

    def _reset_parameters(self, init_type, init_enc_alpha=1.0,
                          init_dec_alpha=1.0):
        """Initialize all parameters and the scaled positional-encoding alphas."""
        # initialize parameters
        initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
    def _transfer_from_teacher(self, transferred_encoder_module):
        """Copy encoder parameters from the teacher model into the student.

        Args:
            transferred_encoder_module (str): "all" to copy every encoder
                parameter, "embed" to copy only the token embedding table.

        Raises:
            NotImplementedError: If an unsupported module name is given.

        """
        if transferred_encoder_module == "all":
            for (n1, p1), (n2, p2) in zip(self.encoder.named_parameters(),
                                          self.teacher.encoder.named_parameters()):
                assert n1 == n2, "It seems that encoder structure is different."
                assert p1.shape == p2.shape, "It seems that encoder size is different."
                p1.data.copy_(p2.data)
        elif transferred_encoder_module == "embed":
            student_shape = self.encoder.embed[0].weight.data.shape
            teacher_shape = self.teacher.encoder.embed[0].weight.data.shape
            assert student_shape == teacher_shape, "It seems that embed dimension is different."
            self.encoder.embed[0].weight.data.copy_(
                self.teacher.encoder.embed[0].weight.data)
        else:
            raise NotImplementedError("Support only all or embed.")

    @property
    def attention_plot_class(self):
        """Return plot class for attention weight plot."""
        return TTSPlot

    @property
    def base_plot_keys(self):
        """Return base key names to plot during training.

        keys should match what `chainer.reporter` reports.
        If you add the key `loss`, the reporter will report `main/loss`
        and `validation/main/loss` values.
        also `loss.png` will be created as a figure visualizing `main/loss`
        and `validation/main/loss` values.

        Returns:
            list: List of strings which are base keys to plot during training.

        """
        plot_keys = ["loss", "l1_loss", "duration_loss"]
        if self.use_scaled_pos_enc:
            plot_keys += ["encoder_alpha", "decoder_alpha"]

        return plot_keys
    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = False,
        decoder_normalize_before: bool = False,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        reduction_factor: int = 1,
        spk_embed_dim: int = None,
        spk_embed_integration_type: str = "add",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        duration_predictor_dropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech module.

        Builds the full text-to-mel pipeline: token embedding + transformer
        encoder, optional GST style encoder and speaker-embedding projection,
        duration predictor, length regulator, transformer decoder (an
        ``Encoder`` instance reused as decoder), output projection, and an
        optional postnet. Dropout/initialization arguments mirror the
        corresponding submodule options; see each submodule for details.
        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # NOTE(review): the last token id is reserved as <eos>
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst
        self.spk_embed_dim = spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )

        # define final projection; emits reduction_factor frames per step
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        ))

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeechLoss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking)
class FastSpeech2(AbsTTS):
    """FastSpeech2 module.

    This is a module of FastSpeech2 described in `FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and
    energy, we use token-averaged value introduced in `FastPitch: Parallel
    Text-to-speech with Pitch Prediction`_.

    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
        https://arxiv.org/abs/2006.04558
    .. _`FastPitch: Parallel Text-to-speech with Pitch Prediction`:
        https://arxiv.org/abs/2006.06873

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        reduction_factor: int = 1,
        encoder_type: str = "transformer",
        decoder_type: str = "transformer",
        # only for conformer
        conformer_rel_pos_type: str = "legacy",
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        zero_triu: bool = False,
        conformer_enc_kernel_size: int = 7,
        conformer_dec_kernel_size: int = 31,
        # duration predictor
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
        # pretrained spk emb
        spk_embed_dim: int = None,
        spk_embed_integration_type: str = "add",
        # GST
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        duration_predictor_dropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech2 module.

        Builds encoder (transformer or conformer), optional GST and speaker
        embedding projection, duration/pitch/energy predictors with their
        embedding convolutions, length regulator, decoder, output projection,
        optional postnet, and the FastSpeech2 loss.
        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # NOTE(review): the last token id is reserved as <eos>
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst
        self.spk_embed_dim = spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # check relative positional encoding compatibility
        if "conformer" in [encoder_type, decoder_type]:
            if conformer_rel_pos_type == "legacy":
                if conformer_pos_enc_layer_type == "rel_pos":
                    conformer_pos_enc_layer_type = "legacy_rel_pos"
                    logging.warning(
                        "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_pos_enc_layer_type = 'latest'.")
                if conformer_self_attn_layer_type == "rel_selfattn":
                    conformer_self_attn_layer_type = "legacy_rel_selfattn"
                    logging.warning(
                        "Fallback to "
                        "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                        "due to the compatibility. If you want to use the new one, "
                        "please use conformer_pos_enc_layer_type = 'latest'.")
            elif conformer_rel_pos_type == "latest":
                assert conformer_pos_enc_layer_type != "legacy_rel_pos"
                assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
            else:
                raise ValueError(
                    f"Unknown rel_pos_type: {conformer_rel_pos_type}")

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        if encoder_type == "transformer":
            self.encoder = TransformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif encoder_type == "conformer":
            self.encoder = ConformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_enc_kernel_size,
                zero_triu=zero_triu,
            )
        else:
            raise ValueError(f"{encoder_type} is not supported.")

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )

        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        if decoder_type == "transformer":
            self.decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif decoder_type == "conformer":
            self.decoder = ConformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size,
            )
        else:
            raise ValueError(f"{decoder_type} is not supported.")

        # define final projection; emits reduction_factor frames per step
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        ))

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeech2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking)
""" text = text[:, :text_lengths.max()] # for data-parallel speech = speech[:, :speech_lengths.max()] # for data-parallel durations = durations[:, :durations_lengths.max()] # for data-parallel pitch = pitch[:, :pitch_lengths.max()] # for data-parallel energy = energy[:, :energy_lengths.max()] # for data-parallel batch_size = text.size(0) # Add eos at the last of sequence xs = F.pad(text, [0, 1], "constant", self.padding_idx) for i, l in enumerate(text_lengths): xs[i, l] = self.eos ilens = text_lengths + 1 ys, ds, ps, es = speech, durations, pitch, energy olens = speech_lengths # forward propagation before_outs, after_outs, d_outs, p_outs, e_outs = self._forward( xs, ilens, ys, olens, ds, ps, es, spembs=spembs, is_inference=False) # modify mod part of groundtruth if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] # calculate loss if self.postnet is None: after_outs = None # calculate loss l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion( after_outs=after_outs, before_outs=before_outs, d_outs=d_outs, p_outs=p_outs, e_outs=e_outs, ys=ys, ds=ds, ps=ps, es=es, ilens=ilens, olens=olens, ) loss = l1_loss + duration_loss + pitch_loss + energy_loss stats = dict( l1_loss=l1_loss.item(), duration_loss=duration_loss.item(), pitch_loss=pitch_loss.item(), energy_loss=energy_loss.item(), loss=loss.item(), ) # report extra information if self.encoder_type == "transformer" and self.use_scaled_pos_enc: stats.update( encoder_alpha=self.encoder.embed[-1].alpha.data.item(), ) if self.decoder_type == "transformer" and self.use_scaled_pos_enc: stats.update( decoder_alpha=self.decoder.embed[-1].alpha.data.item(), ) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight def _forward( self, xs: torch.Tensor, ilens: torch.Tensor, ys: torch.Tensor = None, olens: torch.Tensor = None, ds: torch.Tensor = None, ps: torch.Tensor = 
    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor = None,
        olens: torch.Tensor = None,
        ds: torch.Tensor = None,
        ps: torch.Tensor = None,
        es: torch.Tensor = None,
        spembs: torch.Tensor = None,
        is_inference: bool = False,
        alpha: float = 1.0,
    ) -> Sequence[torch.Tensor]:
        """Shared forward path for training and inference.

        In training, ground-truth durations/pitch/energy (ds/ps/es) drive the
        length regulator and variance embeddings; in inference, the
        predictors' own outputs are used and ``alpha`` scales durations
        (speech speed control).
        """
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate with GST
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(ilens).to(xs.device)

        # optionally detach so variance-predictor gradients do not flow
        # back into the encoder
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, Tmax)
            # use prediction in inference
            p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, d_outs, alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, ds)  # (B, Lmax, adim)

        # forward decoder; in inference the output length is unknown in
        # advance, so no target mask is applied
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs, p_outs, e_outs
""" x, y = text, speech spemb, d, p, e = spembs, durations, pitch, energy # add eos at the last of sequence x = F.pad(x, [0, 1], "constant", self.eos) # setup batch axis ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) xs, ys = x.unsqueeze(0), None if y is not None: ys = y.unsqueeze(0) if spemb is not None: spembs = spemb.unsqueeze(0) if use_teacher_forcing: # use groundtruth of duration, pitch, and energy ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0) _, outs, *_ = self._forward( xs, ilens, ys, ds=ds, ps=ps, es=es, spembs=spembs, ) # (1, L, odim) else: _, outs, *_ = self._forward( xs, ilens, ys, spembs=spembs, is_inference=True, alpha=alpha, ) # (1, L, odim) return outs[0], None, None def inference_pseudo( self, text: torch.Tensor, speech: torch.Tensor = None, spembs: torch.Tensor = None, durations: torch.Tensor = None, pitch: torch.Tensor = None, energy: torch.Tensor = None, alpha: float = 1.0, use_teacher_forcing: bool = False, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the sequence of features given the sequences of characters. Args: text (LongTensor): Input sequence of characters (T,). speech (Tensor, optional): Feature sequence to extract style (N, idim). spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). durations (LongTensor, optional): Groundtruth of duration (T + 1,). pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1). energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1). alpha (float, optional): Alpha to control the speed. use_teacher_forcing (bool, optional): Whether to use teacher forcing. If true, groundtruth of duration, pitch and energy will be used. Returns: Tensor: Output sequence of features (L, odim). None: Dummy for compatibility. None: Dummy for compatibility. 
""" record = {23, 33, 70, 72, 76} def get(x, words=3): l = x.size(0) count = 0 b = [] i = 0 while i < l: if x[i].item() == 1 and i < l - 1 and x[i + 1].item() in record: count += 1 b.append(x[i + 1]) i += 2 elif x[i].item() == 1: count += 1 else: b.append(x[i]) if count == words: return torch.tensor(b).long().cuda(), torch.tensor([ x[j] for j in range(i, l) if x[j].item() != 1 ]).long().cuda() i += 1 x, y = text, speech spemb, d, p, e = spembs, durations, pitch, energy # add eos at the last of sequence # x = F.pad(x, [0, 1], "constant", self.eos) words = 3 buffer_len = 100 count = 0 diff = 0 buffer = torch.tensor([]).long().cuda() for chunk in x: # setup batch axis chunk, pseudo = get(chunk, words) start = buffer.size(0) diff = pseudo.size(0) buffer = torch.cat([buffer, chunk, pseudo]) end = buffer.size(0) - diff # print(start, end, diff, chunk.size(0), cur.size(0)) if buffer.size(0) > buffer_len: start -= buffer.size(0) - buffer_len end -= buffer.size(0) - buffer_len buffer = buffer[-buffer_len:] # print(start, end) buffer = F.pad(buffer, [0, 1], "constant", self.eos) xs, ys = buffer.unsqueeze(0), None ilens = torch.tensor([xs.shape[1]], dtype=torch.long, device=x[0].device) _, outs, pre, dur, *_ = self.inference_forward( xs, ilens, ys, start=start, end=end, spembs=spembs, is_inference=True, alpha=alpha, ) # (1, L, odim) buffer = buffer[:-diff - 1] import numpy as np count += 1 print(count, dur) np.save("/nolan/inference/{}_{:02d}.mel.npy".format("test", count), outs[pre:pre + dur, :].data.cpu().numpy()) # if count == 1: # wav = vocoder.inference(outs[pre:pre+dur, :]) # sf.write("/nolan/inference/{}.wav".format("test"), wav.data.cpu().numpy(), 22050, "PCM_16") # import time # print(time.time()) return outs[0], None, None def inference_( self, text: torch.Tensor, speech: torch.Tensor = None, spembs: torch.Tensor = None, durations: torch.Tensor = None, pitch: torch.Tensor = None, energy: torch.Tensor = None, alpha: float = 1.0, use_teacher_forcing: bool = False, ) 
-> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: """Generate the sequence of features given the sequences of characters. Args: text (LongTensor): Input sequence of characters (T,). speech (Tensor, optional): Feature sequence to extract style (N, idim). spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). durations (LongTensor, optional): Groundtruth of duration (T + 1,). pitch (Tensor, optional): Groundtruth of token-averaged pitch (T + 1, 1). energy (Tensor, optional): Groundtruth of token-averaged energy (T + 1, 1). alpha (float, optional): Alpha to control the speed. use_teacher_forcing (bool, optional): Whether to use teacher forcing. If true, groundtruth of duration, pitch and energy will be used. Returns: Tensor: Output sequence of features (L, odim). None: Dummy for compatibility. None: Dummy for compatibility. """ record = {23, 33, 70, 72, 76} def get(x, words=3, lookahead=1): l = x.size(0) count = 0 b = [] i = 0 while i < l: if x[i].item() == 1 and i < l - 1 and x[i + 1].item() in record: count += 1 b.append(x[i + 1]) i += 2 elif x[i].item() == 1: count += 1 else: b.append(x[i]) if count == words: if b and (len(b) != 1 or b[0] not in record): chunk_len = len(b) tmp_count = 0 j = i + 1 while j < l: if tmp_count == lookahead: break if x[j].item() == 1 and i < l - 1 and x[ j + 1].item() in record: tmp_count += 1 b.append(x[j + 1]) j += 2 elif x[j].item() == 1: tmp_count += 1 else: b.append(x[j]) j += 1 yield chunk_len, torch.tensor(b).long().cuda() b.clear() count = 0 i += 1 if b and (len(b) != 1 or b[0] not in record): chunk_len = len(b) yield chunk_len, torch.tensor(b).long().cuda() x, y = text, speech spemb, d, p, e = spembs, durations, pitch, energy # add eos at the last of sequence # x = F.pad(x, [0, 1], "constant", self.eos) words = 3 lookahead = 1 buffer = 100 count = 0 diff = 0 chunk = torch.tensor([]).long().cuda() for cur_len, cur in get(x, words, lookahead): # setup batch axis start = chunk.size(0) diff = cur.size(0) - cur_len 
chunk = torch.cat([chunk, cur]) end = chunk.size(0) - diff # print(start, end, diff, chunk.size(0), cur.size(0)) if chunk.size(0) > buffer: start -= chunk.size(0) - buffer end -= chunk.size(0) - buffer chunk = chunk[-buffer:] # print(start, end) chunk = F.pad(chunk, [0, 1], "constant", self.eos) xs, ys = chunk.unsqueeze(0), None ilens = torch.tensor([xs.shape[1]], dtype=torch.long, device=x.device) _, outs, pre, dur, *_ = self.inference_forward( xs, ilens, ys, start=start, end=end, spembs=spembs, is_inference=True, alpha=alpha, ) # (1, L, odim) chunk = chunk[:-diff - 1] import numpy as np count += 1 print(count, dur) np.save("/nolan/inference/{}_{:02d}.mel.npy".format("test", count), outs[pre:pre + dur, :].data.cpu().numpy()) # if count == 1: # wav = vocoder.inference(outs[pre:pre+dur, :]) # sf.write("/nolan/inference/{}.wav".format("test"), wav.data.cpu().numpy(), 22050, "PCM_16") # import time # print(time.time()) return outs[0], None, None def inference_forward( self, xs: torch.Tensor, ilens: torch.Tensor, ys: torch.Tensor = None, olens: torch.Tensor = None, ds: torch.Tensor = None, ps: torch.Tensor = None, es: torch.Tensor = None, spembs: torch.Tensor = None, is_inference: bool = False, alpha: float = 1.0, start=0, end=100, ) -> Sequence[torch.Tensor]: # forward encoder x_masks = self._source_mask(ilens) hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) # integrate with GST if self.use_gst: style_embs = self.gst(ys) hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spembs) # forward duration predictor and variance predictors d_masks = make_pad_mask(ilens).to(xs.device) if self.stop_gradient_from_pitch_predictor: p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1)) else: p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1)) if self.stop_gradient_from_energy_predictor: e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1)) else: e_outs = 
self.energy_predictor(hs, d_masks.unsqueeze(-1)) d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) # use prediction in inference p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2) e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2) hs = hs + e_embs + p_embs hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim) # forward decoder h_masks = None zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet(before_outs.transpose( 1, 2)).transpose(1, 2) pre = torch.sum(d_outs[0, :start]).item() dur = torch.sum(d_outs[0, start:end]).item() return before_outs, after_outs.squeeze( 0), pre, dur, d_outs, p_outs, e_outs def _integrate_with_spk_embed(self, hs: torch.Tensor, spembs: torch.Tensor) -> torch.Tensor: """Integrate speaker embedding with hidden states. Args: hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). """ if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spembs = self.projection(F.normalize(spembs)) hs = hs + spembs.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spembs = F.normalize(spembs).unsqueeze(1).expand( -1, hs.size(1), -1) hs = self.projection(torch.cat([hs, spembs], dim=-1)) else: raise NotImplementedError("support only add or concat.") return hs def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: """Make masks for self-attention. Args: ilens (LongTensor): Batch of lengths (B,). Returns: Tensor: Mask tensor for self-attention. 
dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) def _reset_parameters(self, init_type: str, init_enc_alpha: float, init_dec_alpha: float): # initialize parameters if init_type != "pytorch": initialize(self, init_type) # initialize alpha in scaled positional encoding if self.encoder_type == "transformer" and self.use_scaled_pos_enc: self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) if self.decoder_type == "transformer" and self.use_scaled_pos_enc: self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
class FastSpeech(AbsTTS):
    """FastSpeech module for end-to-end text-to-speech.

    This is a module of FastSpeech, feed-forward Transformer with duration
    predictor described in `FastSpeech: Fast, Robust and Controllable Text to
    Speech`_, which does not require any auto-regressive processing during
    inference, resulting in fast decoding compared with auto-regressive
    Transformer.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        elayers (int, optional): Number of encoder layers.
        eunits (int, optional): Number of encoder hidden units.
        dlayers (int, optional): Number of decoder layers.
        dunits (int, optional): Number of decoder hidden units.
        use_scaled_pos_enc (bool, optional):
            Whether to use trainable scaled positional encoding.
        encoder_normalize_before (bool, optional):
            Whether to perform layer normalization before encoder block.
        decoder_normalize_before (bool, optional):
            Whether to perform layer normalization before decoder block.
        encoder_concat_after (bool, optional): Whether to concatenate
            attention layer's input and output in encoder.
        decoder_concat_after (bool, optional): Whether to concatenate
            attention layer's input and output in decoder.
        duration_predictor_layers (int, optional):
            Number of duration predictor layers.
        duration_predictor_chans (int, optional):
            Number of duration predictor channels.
        duration_predictor_kernel_size (int, optional):
            Kernel size of duration predictor.
        spk_embed_dim (int, optional): Number of speaker embedding dimensions.
        spk_embed_integration_type: How to integrate speaker embedding.
        use_gst (str, optional): Whether to use global style token.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        gst_conv_layers (int, optional): The number of conv layers in GST.
        gst_conv_chans_list: (Sequence[int], optional):
            List of the number of channels of conv layers in GST.
        gst_conv_kernel_size (int, optional): Kernel size of conv layers in GST.
        gst_conv_stride (int, optional): Stride size of conv layers in GST.
        gst_gru_layers (int, optional): The number of GRU layers in GST.
        gst_gru_units (int, optional): The number of GRU units in GST.
        reduction_factor (int, optional): Reduction factor.
        transformer_enc_dropout_rate (float, optional):
            Dropout rate in encoder except attention & positional encoding.
        transformer_enc_positional_dropout_rate (float, optional):
            Dropout rate after encoder positional encoding.
        transformer_enc_attn_dropout_rate (float, optional):
            Dropout rate in encoder self-attention module.
        transformer_dec_dropout_rate (float, optional):
            Dropout rate in decoder except attention & positional encoding.
        transformer_dec_positional_dropout_rate (float, optional):
            Dropout rate after decoder positional encoding.
        transformer_dec_attn_dropout_rate (float, optional):
            Dropout rate in decoder self-attention module.
        init_type (str, optional): How to initialize transformer parameters.
        init_enc_alpha (float, optional):
            Initial value of alpha in scaled pos encoding of the encoder.
        init_dec_alpha (float, optional):
            Initial value of alpha in scaled pos encoding of the decoder.
        use_masking (bool, optional):
            Whether to apply masking for padded part in loss calculation.
        use_weighted_masking (bool, optional):
            Whether to apply weighted masking in loss calculation.

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = False,
        decoder_normalize_before: bool = False,
        is_spk_layer_norm: bool = False,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        reduction_factor: int = 1,
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        duration_predictor_dropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        hparams=None,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech module."""
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = reduction_factor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst
        self.spk_embed_dim = spk_embed_dim
        self.hparams = hparams
        # NOTE(review): hparams is dereferenced unconditionally below, so the
        # default hparams=None would raise AttributeError here — presumably
        # callers always pass a populated hparams object; confirm.
        if self.hparams.is_multi_speakers:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        # print(idim)
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        if self.hparams.is_multi_speakers:
            self.speaker_embedding = torch.nn.Embedding(
                hparams.n_speakers, self.spk_embed_dim)
            # Xavier-style uniform init for the speaker embedding table
            std = sqrt(2.0 / (hparams.n_speakers + self.spk_embed_dim))
            val = sqrt(3.0) * std  # uniform bounds for std
            self.speaker_embedding.weight.data.uniform_(-val, val)
            self.spkemb_projection = torch.nn.Linear(hparams.spk_embed_dim,
                                                     hparams.spk_embed_dim)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            is_spk_layer_norm=is_spk_layer_norm,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            hparams=hparams)

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
                hparams=hparams)
            if self.hparams.style_embed_integration_type == "concat":
                # projects [hs; style] back down to adim
                self.gst_projection = torch.nn.Linear(adim + adim, adim)

        # define additional projection for speaker embedding
        if self.hparams.is_multi_speakers:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
            hparams=hparams)

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            is_spk_layer_norm=is_spk_layer_norm,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            hparams=hparams)

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        ))

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeechLoss(
            use_masking=use_masking, use_weighted_masking=use_weighted_masking)

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor = None,
        olens: torch.Tensor = None,
        ds: torch.Tensor = None,
        spk_ids: torch.Tensor = None,
        style_ids: torch.Tensor = None,
        utt_mels: list = None,
        is_inference: bool = False,
        alpha: float = 1.0,
    ) -> Sequence[torch.Tensor]:
        """Shared forward pass for training and inference.

        Returns (before_outs, after_outs, d_outs).
        """
        # integrate speaker embedding
        # NOTE(review): spembs_ is only bound in this branch but is passed to
        # the encoder/decoder/GST unconditionally below — with
        # hparams.is_multi_speakers == False this raises NameError; presumably
        # this model is only run in multi-speaker mode. Confirm.
        if self.hparams.is_multi_speakers:
            # print('spk_ids.shape=',spk_ids.shape)
            spembs_ = self.spkemb_projection(self.speaker_embedding(spk_ids))

        # forward encoder
        x_masks = self._source_mask(ilens)
        # print(spembs_)
        hs, _ = self.encoder(xs, x_masks, spembs_=spembs_)  # (B, Tmax, adim)
        hs = self._integrate_with_spk_embed(
            hs, spembs_) if self.hparams.is_multi_speakers else hs

        if self.use_gst:
            if self.hparams.is_partial_refine and self.hparams.is_refine_style:
                # style refinement: encode every reference utterance mel, then
                # let the model attend over the resulting style tokens
                style_embs = []
                for i in range(len(utt_mels)):
                    style_embs.append(
                        self.gst(to_gpu(utt_mels[i]),
                                 spembs_=spembs_))  #(1, gst_token_dim)
                style_embs = torch.cat(style_embs, dim=0)  #(17, gst_token_dim)
                style_tokens = style_embs.unsqueeze(0).expand(
                    hs.size(0), -1, -1)  #(B, 17, gst_token_dim)
                style_embs = self.gst.choosestl(
                    hs, style_tokens)  #(B, Tx, gst_token_dim)
            else:
                style_embs = self.gst(ys, spembs_=spembs_)
            # integrate with GST
            if self.hparams.style_embed_integration_type == "concat":
                hs = self.gst_projection(
                    torch.cat([
                        hs,
                        style_embs.unsqueeze(1).expand(-1, hs.size(1), -1)
                    ],
                              dim=-1))
                print('spembs_.shape=', spembs_.shape)
            else:
                hs = hs + style_embs.unsqueeze(1)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(
                hs, d_masks, spembs_=spembs_)  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens,
                                       alpha)  # (B, Lmax, adim)
        else:
            # training: expand with ground-truth durations
            d_outs = self.duration_predictor(hs, d_masks,
                                             spembs_=spembs_)  # (B, Tmax)
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks, spembs_=spembs_)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        spk_ids: torch.Tensor = None,
        style_ids: torch.Tensor = None,
        utt_mels: list = None
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            durations (LongTensor): Batch of padded durations (B, Tmax).
            durations_lengths (LongTensor): Batch of duration lengths (B,).
            spembs (Tensor, optional):
                Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value.

        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel
        durations = durations[:, :durations_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        xs = text
        ilens = text_lengths
        ys, ds = speech, durations
        olens = speech_lengths

        # forward propagation
        before_outs, after_outs, d_outs = self._forward(xs,
                                                        ilens,
                                                        ys,
                                                        olens,
                                                        ds,
                                                        spk_ids=spk_ids,
                                                        style_ids=style_ids,
                                                        utt_mels=utt_mels,
                                                        is_inference=False)

        # modifiy mod part of groundtruth
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss
        if self.postnet is None:
            after_outs = None
        l1_loss, duration_loss, l2_loss = self.criterion(
            after_outs, before_outs, d_outs, ys, ds, ilens, olens)
        if self.hparams.use_ssim_loss:
            # NOTE(review): pytorch_ssim is imported lazily; the 4-argument call
            # signature here is non-standard for that package — confirm a
            # project-local fork is used.
            import pytorch_ssim
            ssim_loss = pytorch_ssim.SSIM()
            ssim_out = 3.0 * (2.0 -
                              ssim_loss(before_outs, after_outs, ys, olens))
        # NOTE(review): `loss` is only bound for loss_type in
        # {"L1", "L2", "L1_L2"}; any other value raises NameError below.
        if self.hparams.loss_type == "L1":
            loss = l1_loss + duration_loss + (
                ssim_out if self.hparams.use_ssim_loss else 0)
        if self.hparams.loss_type == "L2":
            loss = l2_loss + duration_loss + (
                ssim_out if self.hparams.use_ssim_loss else 0)
        if self.hparams.loss_type == "L1_L2":
            loss = l1_loss + duration_loss + (
                ssim_out if self.hparams.use_ssim_loss else 0) + l2_loss

        stats = dict(
            L1=l1_loss.item() if self.hparams.loss_type == "L1" else 0,
            L2=l2_loss.item() if self.hparams.loss_type == "L2" else 0,
            L1_L2=l1_loss.item() +
            l2_loss.item() if self.hparams.loss_type == "L1_L2" else 0,
            duration_loss=duration_loss.item(),
            loss=loss.item(),
            ssim_loss=ssim_out.item() if self.hparams.use_ssim_loss else 0,
        )

        # report extra information
        if self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
            )

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        # NOTE(review): returns three extra values beyond the documented
        # (loss, stats, weight) triple — callers rely on the 6-tuple.
        return loss, stats, weight, after_outs, ys, olens

    def inference(
        self,
        text: torch.Tensor,
        speech: torch.Tensor = None,
        spk_ids: torch.Tensor = None,
        durations: torch.Tensor = None,
        alpha: float = 1.0,
        use_teacher_forcing: bool = False,
        utt_mels: list = None
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T,).
            speech (Tensor, optional): Feature sequence to extract style (N, idim).
            spk_ids (Tensor, optional): Speaker embedding vector (spk_embed_dim,).
            durations (LongTensor, optional): Groundtruth of duration (T + 1,).
            alpha (float, optional): Alpha to control the speed.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.

        Returns:
            Tensor: Output sequence of features (L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.

        """
        x, y = text, speech
        spk_ids, d = spk_ids, durations

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        # if y is not None:
        #     ys = y.unsqueeze(0)
        # if spk_ids is not None:
        #     spk_ids = spk_ids.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds = d.unsqueeze(0)
            _, outs, *_ = self._forward(
                xs,
                ilens,
                ys,
                ds=ds,
                spk_ids=spk_ids,
            )  # (1, L, odim)
        else:
            # inference
            _, outs, _ = self._forward(xs,
                                       ilens,
                                       ys,
                                       spk_ids=spk_ids,
                                       is_inference=True,
                                       alpha=alpha,
                                       utt_mels=utt_mels)  # (1, L, odim)

        return outs[0], None, None

    def _integrate_with_spk_embed(self, hs: torch.Tensor,
                                  spembs: torch.Tensor) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).

        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)

        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type: str, init_enc_alpha: float,
                          init_dec_alpha: float):
        """Initialize module parameters and scaled-pos-enc alphas."""
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)

    @staticmethod
    def _parse_batch(batch, hparams, utt_mels=None):
        """Move a collated batch to GPU and unpack optional id fields.

        The first six batch entries are fixed; speaker / style ids follow only
        when the corresponding hparams flags are set.
        """
        text_padded, input_lengths, mel_padded, output_lengths, duration_padded, duration_lengths = batch[:
                                                                                                         6]
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        mel_padded = to_gpu(mel_padded).float()
        output_lengths = to_gpu(output_lengths).long()
        duration_padded = to_gpu(duration_padded).long()
        duration_lengths = to_gpu(duration_lengths).long()
        idx = 6
        speaker_ids = None
        style_ids = None
        utt_mels = utt_mels
        if hparams.is_multi_speakers:
            speaker_ids = batch[idx]
            speaker_ids = to_gpu(speaker_ids).long()
            idx += 1
        if hparams.is_multi_styles:
            style_ids = batch[idx]
            style_ids = to_gpu(style_ids).long()  #(B,)
            idx += 1
        return (text_padded, input_lengths, mel_padded, output_lengths,
                duration_padded, duration_lengths, speaker_ids, style_ids,
                utt_mels)
    def __init__(
        self,
        idim,
        odim,
        adim: int = 192,
        aheads: int = 2,
        elayers: int = 4,
        eunits: int = 768,
        dlayers: int = 2,
        dunits: int = 768,
        use_scaled_pos_enc: bool = True,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 192,
        duration_predictor_kernel_size: int = 3,
        duration_predictor_dropout_rate: float = 0.1,
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
        ddim: int = 64,
        beta_min: float = 0.05,
        beta_max: float = 20.0,
        pe_scale: int = 1000,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
    ):
        """Initialize GradTTS module.

        Builds a FastSpeech-style transformer text encoder with duration,
        pitch, and energy predictors, a transformer pre-decoder, and a
        score-based Diffusion decoder (ddim/beta_min/beta_max/pe_scale).
        """
        super(GradTTS, self).__init__()
        self.idim = idim
        self.odim = odim
        # eos token id is assumed to be the last symbol in the vocabulary
        self.eos = idim - 1
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        # use idx 0 as padding idx
        self.padding_idx = 0
        self.use_scaled_pos_enc = use_scaled_pos_enc

        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)
        self.encoder = TransformerEncoder(
            idim=adim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )
        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )
        self.length_regulator = LengthRegulator()
        # pre-decoder refines length-regulated hidden states before diffusion
        self.pre_decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
        self.feat_out = torch.nn.Linear(adim, odim)
        # score-based diffusion decoder
        self.decoder = Diffusion(ddim, beta_min, beta_max, pe_scale)
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )
        self.criterion = GradTTSLoss(
            odim,
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
        )
class GradTTS(AbsTTS):
    """Grad-TTS style text-to-speech model with pitch/energy variance adaptors.

    Pipeline: text -> Transformer encoder -> (pitch/energy/duration predictors)
    -> length regulator -> pre-decoder Transformer -> linear projection to the
    prior mean ``mu`` -> score-based diffusion decoder refines ``mu`` into the
    output features.
    """

    def __init__(
        self,
        idim: int,
        odim: int,
        adim: int = 192,
        aheads: int = 2,
        elayers: int = 4,
        eunits: int = 768,
        dlayers: int = 2,
        dunits: int = 768,
        use_scaled_pos_enc: bool = True,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        encoder_normalize_before: bool = True,
        decoder_normalize_before: bool = True,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 192,
        duration_predictor_kernel_size: int = 3,
        duration_predictor_dropout_rate: float = 0.1,
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
        ddim: int = 64,
        beta_min: float = 0.05,
        beta_max: float = 20.0,
        pe_scale: int = 1000,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
    ):
        """Initialize GradTTS module.

        Args:
            idim: Dimension of the inputs (vocabulary size incl. EOS).
            odim: Dimension of the output features.
            adim: Attention dimension shared by encoder/pre-decoder.
            ddim: Hidden dimension of the diffusion decoder.
            beta_min: Minimum noise schedule value for the diffusion decoder.
            beta_max: Maximum noise schedule value for the diffusion decoder.
            pe_scale: Positional-encoding scale passed to the diffusion decoder.
            (remaining arguments configure the sub-modules; see each predictor
            and the Transformer encoder for their meaning.)
        """
        super(GradTTS, self).__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # last vocabulary index is reserved as the EOS symbol
        self.eos = idim - 1
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        # index 0 is used as the padding symbol
        self.padding_idx = 0
        self.use_scaled_pos_enc = use_scaled_pos_enc

        # token embedding used as the encoder input layer
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)
        self.encoder = TransformerEncoder(
            idim=adim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )
        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )
        self.length_regulator = LengthRegulator()
        # NOTE: a Transformer *encoder* stack is reused as the pre-decoder
        self.pre_decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
        # projects pre-decoder output to the feature dimension (prior mean mu)
        self.feat_out = torch.nn.Linear(adim, odim)
        # score-based diffusion decoder refining mu into the final features
        self.decoder = Diffusion(ddim, beta_min, beta_max, pe_scale)
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )
        self.criterion = GradTTSLoss(
            odim,
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking,
        )

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        pitch: Optional[torch.Tensor] = None,
        pitch_lengths: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        energy_lengths: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation (training).

        Returns:
            Tuple of the scalar loss, a stats dict of detached loss values,
            and the gathered weight (batch size) from ``force_gatherable``.
        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel
        durations = durations[:, :durations_lengths.max()]  # for data-parallel
        pitch = pitch[:, :pitch_lengths.max()]  # for data-parallel
        energy = energy[:, :energy_lengths.max()]  # for data-parallel
        batch_size = text.size(0)

        # append an EOS token at the end of each (padded) text sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos  # (B, Tmax, idim)
        ilens = text_lengths + 1

        ys, ds, ps, es = speech, durations, pitch, energy
        olens = speech_lengths
        ys = ys.transpose(1, 2)  # (B, odim, Lmax)
        # pad the target length up to a multiple of 4
        # (presumably required by the diffusion decoder's downsampling —
        #  TODO confirm against the Diffusion implementation)
        if ys.size(2) % 4 != 0:
            ys = torch.cat([
                ys,
                torch.zeros([batch_size, self.odim, 4 - ys.size(2) % 4],
                            dtype=ys.dtype,
                            device=ys.device)
            ], dim=2)

        noise_estimation, z, d_outs, p_outs, e_outs, mu, y_masks = self._forward(
            xs, ilens, ys, olens, ds, ps, es)

        prior_loss, duration_loss, diff_loss, pitch_loss, energy_loss = self.criterion(
            mu, noise_estimation, z, d_outs, p_outs, e_outs, ys, ds, ps, es,
            y_masks, ilens)
        loss = prior_loss + duration_loss + diff_loss + pitch_loss + energy_loss

        stats = dict(
            prior_loss=prior_loss.item(),
            duration_loss=duration_loss.item(),
            diff_loss=diff_loss.item(),
            pitch_loss=pitch_loss.item(),
            energy_loss=energy_loss.item(),
            loss=loss.item(),
        )
        # report the learnable alpha of the scaled positional encodings
        if self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
            )
            stats.update(
                decoder_alpha=self.pre_decoder.embed[-1].alpha.data.item(),
            )

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: Optional[torch.Tensor] = None,
        olens: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
        ps: Optional[torch.Tensor] = None,
        es: Optional[torch.Tensor] = None,
    ):
        """Shared forward computation used during training.

        Encodes text, predicts pitch/energy/durations, expands with ground
        truth durations, and runs the diffusion decoder on the prior mean.
        """
        x_masks = self._source_mask(ilens)  # (B, 1, Tmax)
        y_masks = self._source_mask(olens)  # (B, 1, Lmax)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)
        d_masks = make_pad_mask(ilens).to(xs.device)

        # optionally cut the gradient path from the variance predictors
        # back into the encoder
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))
        d_outs = self.duration_predictor(hs, d_masks)

        # use groundtruth in training
        p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        mu = self.length_regulator(hs, ds)  # (B, Lmax, adim)
        mu, _ = self.pre_decoder(mu, y_masks)  # (B, Lmax, adim)
        mu = self.feat_out(mu)  # (B, Lmax, odim)
        mu = mu.transpose(1, 2)  # (B, odim, Lmax)

        # pad mu and the mask to a multiple of 4 frames, mirroring the target
        # padding done in forward()
        if mu.size(2) % 4 != 0:
            mu = torch.cat([
                mu,
                torch.zeros([mu.size(0), self.odim, 4 - mu.size(2) % 4],
                            dtype=mu.dtype,
                            device=mu.device)
            ], dim=2)
            y_masks = torch.cat([
                y_masks,
                torch.zeros([y_masks.size(0), 1, 4 - y_masks.size(2) % 4],
                            dtype=y_masks.dtype,
                            device=y_masks.device)
            ], dim=2)

        noise_estimation, z = self.decoder(ys, y_masks, mu)
        return noise_estimation, z, d_outs, p_outs, e_outs, mu, y_masks

    def inference(
        self,
        text: torch.Tensor,
        timesteps: int = 20,
        spembs: Optional[torch.Tensor] = None,
        temperature: float = 1.0,
        alpha: float = 1.03,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[torch.Tensor]]:
        """Generate features for a single utterance.

        Args:
            text: Input token id sequence (T,).
            timesteps: Number of reverse diffusion steps.
            temperature: Sampling temperature; noise is divided by this value.
            alpha: Duration scaling factor for the length regulator.

        Returns:
            Tuple of generated features (L, odim) and two ``None`` placeholders
            (kept for interface compatibility with other TTS models).
        """
        x = text
        # append EOS and add a batch axis
        x = F.pad(x, [0, 1], "constant", self.eos)
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs = x.unsqueeze(0)

        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)
        d_masks = make_pad_mask(ilens).to(xs.device)
        d_outs = self.duration_predictor.inference(hs, d_masks)

        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        # use the *predicted* pitch/energy at inference time
        p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
        e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
        hs = hs + e_embs + p_embs

        mu = self.length_regulator(hs, d_outs, alpha)  # (B, Lmax, odim)
        # remember the true output length before padding to a multiple of 4
        length = mu.size(1)
        y_masks = torch.ones([1, 1, mu.size(1)],
                             dtype=torch.int64,
                             device=mu.device)
        mu, _ = self.pre_decoder(mu, y_masks)  # (B, Lmax, adim)
        mu = self.feat_out(mu)  # (B, Lmax, odim)
        mu = mu.transpose(1, 2)  # (B, odim, Lmax)

        if mu.size(2) % 4 != 0:
            mu = torch.cat([
                mu,
                torch.zeros([1, self.odim, 4 - mu.size(2) % 4],
                            dtype=mu.dtype,
                            device=mu.device)
            ], dim=2)
            y_masks = torch.cat([
                y_masks,
                torch.zeros([1, 1, 4 - y_masks.size(2) % 4],
                            dtype=y_masks.dtype,
                            device=y_masks.device)
            ], dim=2)

        # initial latent: prior mean plus temperature-scaled Gaussian noise
        z = mu + torch.randn_like(mu, device=mu.device) / temperature
        out = self.decoder.inference(z, y_masks, mu, timesteps,
                                     length).transpose(1, 2)
        # strip the multiple-of-4 padding before returning
        return out[0, :length, :], None, None

    def decode_inference(self, mu, temperature=1.0, timesteps=10):
        """Run only the diffusion decoder on a given prior mean ``mu`` (L, odim)."""
        length = mu.shape[0]
        olens = torch.tensor([length], dtype=torch.long, device=mu.device)
        y_masks = self._source_mask(olens)
        mu = mu.unsqueeze(0).transpose(1, 2)
        # pad to a multiple of 4 frames, same as in inference()
        if mu.size(2) % 4 != 0:
            mu = torch.cat([
                mu,
                torch.zeros([1, self.odim, 4 - mu.size(2) % 4],
                            dtype=mu.dtype,
                            device=mu.device)
            ], dim=2)
            y_masks = torch.cat([
                y_masks,
                torch.zeros([1, 1, 4 - y_masks.size(2) % 4],
                            dtype=y_masks.dtype,
                            device=y_masks.device)
            ], dim=2)
        z = mu + torch.randn_like(mu, device=mu.device) / temperature
        out = self.decoder.inference(z, y_masks, mu, timesteps).transpose(1, 2)
        return out[0, :length, :]

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make a (B, 1, T) non-padding mask from a length tensor."""
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(
        self,
        init_type: str,
        init_enc_alpha: float,
        init_dec_alpha: float,
    ):
        """Initialize network parameters and the scaled-pos-enc alphas."""
        # initialize parameters
        if init_type != "pytorch":
            initialize(self, init_type)
        # initialize alpha in scaled positional encoding
        if self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
        if self.use_scaled_pos_enc:
            self.pre_decoder.embed[-1].alpha.data = torch.tensor(
                init_dec_alpha)
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    reduction_factor: int = 1,
    encoder_type: str = "transformer",
    decoder_type: str = "transformer",
    # only for conformer
    conformer_rel_pos_type: str = "legacy",
    conformer_pos_enc_layer_type: str = "rel_pos",
    conformer_self_attn_layer_type: str = "rel_selfattn",
    conformer_activation_type: str = "swish",
    use_macaron_style_in_conformer: bool = True,
    use_cnn_in_conformer: bool = True,
    conformer_enc_kernel_size: int = 7,
    conformer_dec_kernel_size: int = 31,
    zero_triu: bool = False,
    # pretrained spk emb
    spk_embed_dim: Optional[int] = None,
    spk_embed_integration_type: str = "add",
    # GST
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    duration_predictor_dropout_rate: float = 0.1,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Initialize FastSpeech module.

    Builds the encoder (Transformer or Conformer), optional GST style
    encoder and speaker-embedding projection, duration predictor, length
    regulator, decoder (a Transformer/Conformer encoder stack reused as
    decoder), output projection, optional postnet, and the loss criterion.
    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    # last vocabulary index is reserved as the EOS symbol
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.encoder_type = encoder_type
    self.decoder_type = decoder_type
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst
    self.spk_embed_dim = spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (ScaledPositionalEncoding
                     if self.use_scaled_pos_enc else PositionalEncoding)

    # check relative positional encoding compatibility
    if "conformer" in [encoder_type, decoder_type]:
        if conformer_rel_pos_type == "legacy":
            if conformer_pos_enc_layer_type == "rel_pos":
                conformer_pos_enc_layer_type = "legacy_rel_pos"
                logging.warning(
                    "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
                logging.warning(
                    "Fallback to "
                    "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_pos_enc_layer_type = 'latest'.")
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
        else:
            raise ValueError(
                f"Unknown rel_pos_type: {conformer_rel_pos_type}")

    # define encoder
    encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                             embedding_dim=adim,
                                             padding_idx=self.padding_idx)
    if encoder_type == "transformer":
        self.encoder = TransformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif encoder_type == "conformer":
        self.encoder = ConformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
        )
    else:
        raise ValueError(f"{encoder_type} is not supported.")

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define additional projection for speaker embedding
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            # "concat" style: project the concatenated hidden + spk emb back
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                              adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    if decoder_type == "transformer":
        self.decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif decoder_type == "conformer":
        self.decoder = ConformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_dec_kernel_size,
        )
    else:
        raise ValueError(f"{decoder_type} is not supported.")

    # define final projection; outputs reduction_factor frames per step
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # define postnet (disabled when postnet_layers == 0)
    self.postnet = (None if postnet_layers == 0 else Postnet(
        idim=idim,
        odim=odim,
        n_layers=postnet_layers,
        n_chans=postnet_chans,
        n_filts=postnet_filts,
        use_batch_norm=use_batch_norm,
        dropout_rate=postnet_dropout_rate,
    ))

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # define criterions
    self.criterion = FastSpeechLoss(
        use_masking=use_masking, use_weighted_masking=use_weighted_masking)
def __init__(self, idim, odim, args=None, com_args=None):
    """Initialize Tacotron2 module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        args (Namespace, optional): Model hyperparameters, including:
            - spk_embed_dim (int): Dimension of the speaker embedding.
            - embed_dim (int): Dimension of character embedding.
            - elayers / eunits: Encoder BLSTM layers / units.
            - econv_layers / econv_filts / econv_chans: Encoder conv config.
            - dlayers / dunits: Decoder LSTM layers / units.
            - prenet_layers / prenet_units: Prenet config.
            - postnet_layers / postnet_filts / postnet_chans: Postnet config.
            - output_activation (str): Activation function name for outputs.
            - use_batch_norm (bool): Whether to use batch normalization.
            - use_concate (bool): Whether to concatenate encoder embedding
              with decoder LSTM outputs.
            - dropout_rate (float) / zoneout_rate (float): Regularization.
            - reduction_factor (int): Reduction factor.
            - use_masking / use_weighted_masking (bool): Loss masking flags.
            - duration_predictor_layers / chans / kernel_size: Duration
              predictor config.
        com_args (Namespace, optional): Common args used to fill
            `use_fe_condition` / `append_position` when absent from `args`.
    """
    # initialize base classes
    TTSInterface.__init__(self)
    torch.nn.Module.__init__(self)

    # fill missing arguments with the defaults declared in add_arguments
    args = fill_missing_args(args, self.add_arguments)
    # backfill the two newer flags from com_args for older configs
    args = vars(args)
    if 'use_fe_condition' not in args.keys():
        args['use_fe_condition'] = com_args.use_fe_condition
    if 'append_position' not in args.keys():
        args['append_position'] = com_args.append_position
    args = argparse.Namespace(**args)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.embed_dim = args.embed_dim
    self.spk_embed_dim = args.spk_embed_dim
    self.reduction_factor = args.reduction_factor
    self.use_fe_condition = args.use_fe_condition
    self.append_position = args.append_position

    # define activation function for the final output
    if args.output_activation is None:
        self.output_activation_fn = None
    elif hasattr(F, args.output_activation):
        self.output_activation_fn = getattr(F, args.output_activation)
    else:
        raise ValueError("there is no such an activation function. (%s)" %
                         args.output_activation)

    # set padding idx
    padding_idx = 0

    # define network modules
    self.enc = Encoder(
        idim=idim,
        embed_dim=args.embed_dim,
        elayers=args.elayers,
        eunits=args.eunits,
        econv_layers=args.econv_layers,
        econv_chans=args.econv_chans,
        econv_filts=args.econv_filts,
        use_batch_norm=args.use_batch_norm,
        use_residual=args.use_residual,
        dropout_rate=args.dropout_rate,
        padding_idx=padding_idx,
        resume=args.encoder_resume,
    )
    # decoder input grows by spk_embed_dim when speaker embedding is used
    dec_idim = (args.eunits if args.spk_embed_dim is None else args.eunits +
                args.spk_embed_dim)
    self.dec = Decoder(
        idim=dec_idim,
        odim=odim,
        dlayers=args.dlayers,
        dunits=args.dunits,
        prenet_layers=args.prenet_layers,
        prenet_units=args.prenet_units,
        postnet_layers=args.postnet_layers,
        postnet_chans=args.postnet_chans,
        postnet_filts=args.postnet_filts,
        output_activation_fn=self.output_activation_fn,
        use_batch_norm=args.use_batch_norm,
        use_concate=args.use_concate,
        dropout_rate=args.dropout_rate,
        zoneout_rate=args.zoneout_rate,
        reduction_factor=args.reduction_factor,
        use_fe_condition=args.use_fe_condition,
        append_position=args.append_position,
    )
    self.duration_predictor = DurationPredictor(
        idim=dec_idim,
        n_layers=args.duration_predictor_layers,
        n_chans=args.duration_predictor_chans,
        kernel_size=args.duration_predictor_kernel_size,
        dropout_rate=args.duration_predictor_dropout_rate,
    )
    reduction = 'none' if args.use_weighted_masking else 'mean'
    self.duration_criterion = DurationPredictorLoss(reduction=reduction)

    # -------------- pitch/energy predictor definition ---------------#
    # NOTE(review): the predictor hyperparameters below are hard-coded
    # rather than taken from args.
    if self.use_fe_condition:
        output_dim = 1
        # pitch prediction
        pitch_predictor_layers = 2
        pitch_predictor_chans = 384
        pitch_predictor_kernel_size = 3
        pitch_predictor_dropout_rate = 0.5
        pitch_embed_kernel_size = 9
        pitch_embed_dropout_rate = 0.5
        self.stop_gradient_from_pitch_predictor = False
        self.pitch_predictor = VariancePredictor(
            idim=dec_idim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout_rate,
            output_dim=output_dim,
        )
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=dec_idim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout_rate),
        )
        # energy prediction
        energy_predictor_layers = 2
        energy_predictor_chans = 384
        energy_predictor_kernel_size = 3
        energy_predictor_dropout_rate = 0.5
        energy_embed_kernel_size = 9
        energy_embed_dropout_rate = 0.5
        self.stop_gradient_from_energy_predictor = False
        self.energy_predictor = VariancePredictor(
            idim=dec_idim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout_rate,
            output_dim=output_dim,
        )
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=dec_idim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout_rate),
        )

    # define criterions
    self.prosody_criterion = prosody_criterions(
        use_masking=args.use_masking,
        use_weighted_masking=args.use_weighted_masking)
    self.taco2_loss = Tacotron2Loss(
        use_masking=args.use_masking,
        use_weighted_masking=args.use_weighted_masking,
    )

    # load pretrained model
    if args.pretrained_model is not None:
        self.load_pretrained_model(args.pretrained_model)

    # report trainable parameter counts per sub-module
    # NOTE(review): the pitch/energy modules below only exist when
    # use_fe_condition is True — these lines would raise AttributeError
    # otherwise; confirm intended usage.
    print('\n############## number of network parameters ##############\n')
    parameters = filter(lambda p: p.requires_grad, self.enc.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for Encoder: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad, self.dec.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for Decoder: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad,
                        self.duration_predictor.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for duration_predictor: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad,
                        self.pitch_predictor.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for pitch_predictor: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad,
                        self.energy_predictor.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for energy_predictor: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad,
                        self.pitch_embed.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for pitch_embed: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad,
                        self.energy_embed.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for energy_embed: %.5fM' % parameters)
    parameters = filter(lambda p: p.requires_grad, self.parameters())
    parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
    print('Trainable Parameters for whole network: %.5fM' % parameters)
    print('\n##########################################################\n')
class FastSpeech(AbsTTS): """FastSpeech module for end-to-end text-to-speech. This is a module of FastSpeech, feed-forward Transformer with duration predictor described in `FastSpeech: Fast, Robust and Controllable Text to Speech`_, which does not require any auto-regressive processing during inference, resulting in fast decoding compared with auto-regressive Transformer. .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`: https://arxiv.org/pdf/1905.09263.pdf Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. elayers (int, optional): Number of encoder layers. eunits (int, optional): Number of encoder hidden units. dlayers (int, optional): Number of decoder layers. dunits (int, optional): Number of decoder hidden units. use_scaled_pos_enc (bool, optional): Whether to use trainable scaled positional encoding. encoder_normalize_before (bool, optional): Whether to perform layer normalization before encoder block. decoder_normalize_before (bool, optional): Whether to perform layer normalization before decoder block. encoder_concat_after (bool, optional): Whether to concatenate attention layer's input and output in encoder. decoder_concat_after (bool, optional): Whether to concatenate attention layer's input and output in decoder. duration_predictor_layers (int, optional): Number of duration predictor layers. duration_predictor_chans (int, optional): Number of duration predictor channels. duration_predictor_kernel_size (int, optional): Kernel size of duration predictor. spk_embed_dim (int, optional): Number of speaker embedding dimensions. spk_embed_integration_type: How to integrate speaker embedding. use_gst (str, optional): Whether to use global style token. gst_tokens (int, optional): The number of GST embeddings. gst_heads (int, optional): The number of heads in GST multihead attention. gst_conv_layers (int, optional): The number of conv layers in GST. 
gst_conv_chans_list: (Sequence[int], optional): List of the number of channels of conv layers in GST. gst_conv_kernel_size (int, optional): Kernal size of conv layers in GST. gst_conv_stride (int, optional): Stride size of conv layers in GST. gst_gru_layers (int, optional): The number of GRU layers in GST. gst_gru_units (int, optional): The number of GRU units in GST. reduction_factor (int, optional): Reduction factor. transformer_enc_dropout_rate (float, optional): Dropout rate in encoder except attention & positional encoding. transformer_enc_positional_dropout_rate (float, optional): Dropout rate after encoder positional encoding. transformer_enc_attn_dropout_rate (float, optional): Dropout rate in encoder self-attention module. transformer_dec_dropout_rate (float, optional): Dropout rate in decoder except attention & positional encoding. transformer_dec_positional_dropout_rate (float, optional): Dropout rate after decoder positional encoding. transformer_dec_attn_dropout_rate (float, optional): Dropout rate in deocoder self-attention module. init_type (str, optional): How to initialize transformer parameters. init_enc_alpha (float, optional): Initial value of alpha in scaled pos encoding of the encoder. init_dec_alpha (float, optional): Initial value of alpha in scaled pos encoding of the decoder. use_masking (bool, optional): Whether to apply masking for padded part in loss calculation. use_weighted_masking (bool, optional): Whether to apply weighted masking in loss calculation. 
""" def __init__( self, # network structure related idim: int, odim: int, adim: int = 384, aheads: int = 4, elayers: int = 6, eunits: int = 1536, dlayers: int = 6, dunits: int = 1536, postnet_layers: int = 5, postnet_chans: int = 512, postnet_filts: int = 5, positionwise_layer_type: str = "conv1d", positionwise_conv_kernel_size: int = 1, use_scaled_pos_enc: bool = True, use_batch_norm: bool = True, encoder_normalize_before: bool = True, decoder_normalize_before: bool = True, encoder_concat_after: bool = False, decoder_concat_after: bool = False, duration_predictor_layers: int = 2, duration_predictor_chans: int = 384, duration_predictor_kernel_size: int = 3, reduction_factor: int = 1, encoder_type: str = "transformer", decoder_type: str = "transformer", # only for conformer conformer_rel_pos_type: str = "legacy", conformer_pos_enc_layer_type: str = "rel_pos", conformer_self_attn_layer_type: str = "rel_selfattn", conformer_activation_type: str = "swish", use_macaron_style_in_conformer: bool = True, use_cnn_in_conformer: bool = True, conformer_enc_kernel_size: int = 7, conformer_dec_kernel_size: int = 31, zero_triu: bool = False, # pretrained spk emb spk_embed_dim: int = None, spk_embed_integration_type: str = "add", # GST use_gst: bool = False, gst_tokens: int = 10, gst_heads: int = 4, gst_conv_layers: int = 6, gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), gst_conv_kernel_size: int = 3, gst_conv_stride: int = 2, gst_gru_layers: int = 1, gst_gru_units: int = 128, # training related transformer_enc_dropout_rate: float = 0.1, transformer_enc_positional_dropout_rate: float = 0.1, transformer_enc_attn_dropout_rate: float = 0.1, transformer_dec_dropout_rate: float = 0.1, transformer_dec_positional_dropout_rate: float = 0.1, transformer_dec_attn_dropout_rate: float = 0.1, duration_predictor_dropout_rate: float = 0.1, postnet_dropout_rate: float = 0.5, init_type: str = "xavier_uniform", init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0, use_masking: 
bool = False, use_weighted_masking: bool = False, ): """Initialize FastSpeech module.""" assert check_argument_types() super().__init__() # store hyperparameters self.idim = idim self.odim = odim self.eos = idim - 1 self.reduction_factor = reduction_factor self.encoder_type = encoder_type self.decoder_type = decoder_type self.use_scaled_pos_enc = use_scaled_pos_enc self.use_gst = use_gst self.spk_embed_dim = spk_embed_dim if self.spk_embed_dim is not None: self.spk_embed_integration_type = spk_embed_integration_type # use idx 0 as padding idx self.padding_idx = 0 # get positional encoding class pos_enc_class = (ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding) # check relative positional encoding compatibility if "conformer" in [encoder_type, decoder_type]: if conformer_rel_pos_type == "legacy": if conformer_pos_enc_layer_type == "rel_pos": conformer_pos_enc_layer_type = "legacy_rel_pos" logging.warning( "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " "due to the compatibility. If you want to use the new one, " "please use conformer_pos_enc_layer_type = 'latest'.") if conformer_self_attn_layer_type == "rel_selfattn": conformer_self_attn_layer_type = "legacy_rel_selfattn" logging.warning( "Fallback to " "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " "due to the compatibility. 
If you want to use the new one, " "please use conformer_pos_enc_layer_type = 'latest'.") elif conformer_rel_pos_type == "latest": assert conformer_pos_enc_layer_type != "legacy_rel_pos" assert conformer_self_attn_layer_type != "legacy_rel_selfattn" else: raise ValueError( f"Unknown rel_pos_type: {conformer_rel_pos_type}") # define encoder encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx) if encoder_type == "transformer": self.encoder = TransformerEncoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) elif encoder_type == "conformer": self.encoder = ConformerEncoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer, pos_enc_layer_type=conformer_pos_enc_layer_type, selfattention_layer_type=conformer_self_attn_layer_type, activation_type=conformer_activation_type, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, ) else: raise ValueError(f"{encoder_type} is not supported.") # define GST if self.use_gst: self.gst = StyleEncoder( idim=odim, # the 
input is mel-spectrogram gst_tokens=gst_tokens, gst_token_dim=adim, gst_heads=gst_heads, conv_layers=gst_conv_layers, conv_chans_list=gst_conv_chans_list, conv_kernel_size=gst_conv_kernel_size, conv_stride=gst_conv_stride, gru_layers=gst_gru_layers, gru_units=gst_gru_units, ) # define additional projection for speaker embedding if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, adim) else: self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) # define duration predictor self.duration_predictor = DurationPredictor( idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans, kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate, ) # define length regulator self.length_regulator = LengthRegulator() # define decoder # NOTE: we use encoder as decoder # because fastspeech's decoder is the same as encoder if decoder_type == "transformer": self.decoder = TransformerEncoder( idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None, dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate, attention_dropout_rate=transformer_dec_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) elif decoder_type == "conformer": self.decoder = ConformerEncoder( idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None, dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate, attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, positionwise_layer_type=positionwise_layer_type, 
positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer, pos_enc_layer_type=conformer_pos_enc_layer_type, selfattention_layer_type=conformer_self_attn_layer_type, activation_type=conformer_activation_type, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size, ) else: raise ValueError(f"{decoder_type} is not supported.") # define final projection self.feat_out = torch.nn.Linear(adim, odim * reduction_factor) # define postnet self.postnet = (None if postnet_layers == 0 else Postnet( idim=idim, odim=odim, n_layers=postnet_layers, n_chans=postnet_chans, n_filts=postnet_filts, use_batch_norm=use_batch_norm, dropout_rate=postnet_dropout_rate, )) # initialize parameters self._reset_parameters( init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha, ) # define criterions self.criterion = FastSpeechLoss( use_masking=use_masking, use_weighted_masking=use_weighted_masking) def _forward( self, xs: torch.Tensor, ilens: torch.Tensor, ys: torch.Tensor = None, olens: torch.Tensor = None, ds: torch.Tensor = None, spembs: torch.Tensor = None, is_inference: bool = False, alpha: float = 1.0, ) -> Sequence[torch.Tensor]: # forward encoder x_masks = self._source_mask(ilens) hs, _ = self.encoder(xs, x_masks) # (B, Tmax, adim) # integrate with GST if self.use_gst: style_embs = self.gst(ys) hs = hs + style_embs.unsqueeze(1) # integrate speaker embedding if self.spk_embed_dim is not None: hs = self._integrate_with_spk_embed(hs, spembs) # forward duration predictor and length regulator d_masks = make_pad_mask(ilens).to(xs.device) if is_inference: d_outs = self.duration_predictor.inference(hs, d_masks) # (B, Tmax) hs = self.length_regulator(hs, d_outs, alpha) # (B, Lmax, adim) else: d_outs = self.duration_predictor(hs, d_masks) # (B, Tmax) hs = self.length_regulator(hs, ds) # (B, Lmax, adim) # forward decoder if olens is not None and not is_inference: if self.reduction_factor > 1: olens_in 
= olens.new( [olen // self.reduction_factor for olen in olens]) else: olens_in = olens h_masks = self._source_mask(olens_in) else: h_masks = None zs, _ = self.decoder(hs, h_masks) # (B, Lmax, adim) before_outs = self.feat_out(zs).view(zs.size(0), -1, self.odim) # (B, Lmax, odim) # postnet -> (B, Lmax//r * r, odim) if self.postnet is None: after_outs = before_outs else: after_outs = before_outs + self.postnet(before_outs.transpose( 1, 2)).transpose(1, 2) return before_outs, after_outs, d_outs def forward( self, text: torch.Tensor, text_lengths: torch.Tensor, speech: torch.Tensor, speech_lengths: torch.Tensor, durations: torch.Tensor, durations_lengths: torch.Tensor, spembs: torch.Tensor = None, ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: """Calculate forward propagation. Args: text (LongTensor): Batch of padded character ids (B, Tmax). text_lengths (LongTensor): Batch of lengths of each input (B,). speech (Tensor): Batch of padded target features (B, Lmax, odim). speech_lengths (LongTensor): Batch of the lengths of each target (B,). durations (LongTensor): Batch of padded durations (B, Tmax + 1). durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1). spembs (Tensor, optional): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Loss scalar value. Dict: Statistics to be monitored. Tensor: Weight value. 
""" text = text[:, :text_lengths.max()] # for data-parallel speech = speech[:, :speech_lengths.max()] # for data-parallel durations = durations[:, :durations_lengths.max()] # for data-parallel batch_size = text.size(0) # Add eos at the last of sequence xs = F.pad(text, [0, 1], "constant", self.padding_idx) for i, l in enumerate(text_lengths): xs[i, l] = self.eos ilens = text_lengths + 1 ys, ds = speech, durations olens = speech_lengths # forward propagation before_outs, after_outs, d_outs = self._forward(xs, ilens, ys, olens, ds, spembs=spembs, is_inference=False) # modifiy mod part of groundtruth if self.reduction_factor > 1: olens = olens.new( [olen - olen % self.reduction_factor for olen in olens]) max_olen = max(olens) ys = ys[:, :max_olen] # calculate loss if self.postnet is None: after_outs = None l1_loss, duration_loss = self.criterion(after_outs, before_outs, d_outs, ys, ds, ilens, olens) loss = l1_loss + duration_loss stats = dict( l1_loss=l1_loss.item(), duration_loss=duration_loss.item(), loss=loss.item(), ) # report extra information if self.encoder_type == "transformer" and self.use_scaled_pos_enc: stats.update( encoder_alpha=self.encoder.embed[-1].alpha.data.item(), ) if self.decoder_type == "transformer" and self.use_scaled_pos_enc: stats.update( decoder_alpha=self.decoder.embed[-1].alpha.data.item(), ) loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) return loss, stats, weight def inference( self, text: torch.Tensor, speech: torch.Tensor = None, spembs: torch.Tensor = None, durations: torch.Tensor = None, alpha: float = 1.0, use_teacher_forcing: bool = False, ) -> Dict[str, torch.Tensor]: """Generate the sequence of features given the sequences of characters. Args: text (LongTensor): Input sequence of characters (T,). speech (Tensor, optional): Feature sequence to extract style (N, idim). spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,). 
durations (LongTensor, optional): Groundtruth of duration (T + 1,). alpha (float, optional): Alpha to control the speed. use_teacher_forcing (bool, optional): Whether to use teacher forcing. If true, groundtruth of duration, pitch and energy will be used. Returns: Dict[str, Tensor]: Output dict including the following items: * feat_gen (Tensor): Output sequence of features (L, odim). * duration (Tensor): Duration sequence (T + 1,). """ x, y = text, speech spemb, d = spembs, durations # add eos at the last of sequence x = F.pad(x, [0, 1], "constant", self.eos) # setup batch axis ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device) xs, ys = x.unsqueeze(0), None if y is not None: ys = y.unsqueeze(0) if spemb is not None: spembs = spemb.unsqueeze(0) if use_teacher_forcing: # use groundtruth of duration, pitch, and energy ds = d.unsqueeze(0) _, outs, d_outs = self._forward( xs, ilens, ys, ds=ds, spembs=spembs, ) # (1, L, odim) else: # inference _, outs, d_outs = self._forward( xs, ilens, ys, spembs=spembs, is_inference=True, alpha=alpha, ) # (1, L, odim) return dict(feat_gen=outs[0], duration=d_outs[0]) def _integrate_with_spk_embed(self, hs: torch.Tensor, spembs: torch.Tensor) -> torch.Tensor: """Integrate speaker embedding with hidden states. Args: hs (Tensor): Batch of hidden state sequences (B, Tmax, adim). spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim). Returns: Tensor: Batch of integrated hidden state sequences (B, Tmax, adim). 
""" if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spembs = self.projection(F.normalize(spembs)) hs = hs + spembs.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spembs = F.normalize(spembs).unsqueeze(1).expand( -1, hs.size(1), -1) hs = self.projection(torch.cat([hs, spembs], dim=-1)) else: raise NotImplementedError("support only add or concat.") return hs def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: """Make masks for self-attention. Args: ilens (LongTensor): Batch of lengths (B,). Returns: Tensor: Mask tensor for self-attention. dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) def _reset_parameters(self, init_type: str, init_enc_alpha: float, init_dec_alpha: float): # initialize parameters if init_type != "pytorch": initialize(self, init_type) # initialize alpha in scaled positional encoding if self.encoder_type == "transformer" and self.use_scaled_pos_enc: self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) if self.decoder_type == "transformer" and self.use_scaled_pos_enc: self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
class FastSpeech2(AbsTTS):
    """FastSpeech2 module.

    This is a module of FastSpeech2 described in `FastSpeech 2: Fast and
    High-Quality End-to-End Text to Speech`_. Instead of quantized pitch and
    energy, we use token-averaged value introduced in `FastPitch: Parallel
    Text-to-speech with Pitch Prediction`_.

    .. _`FastSpeech 2: Fast and High-Quality End-to-End Text to Speech`:
        https://arxiv.org/abs/2006.04558
    .. _`FastPitch: Parallel Text-to-speech with Pitch Prediction`:
        https://arxiv.org/abs/2006.06873

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = False,
        decoder_normalize_before: bool = False,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        reduction_factor: int = 1,
        encoder_type: str = "transformer",
        decoder_type: str = "transformer",
        # only for conformer
        conformer_pos_enc_layer_type: str = "rel_pos",
        conformer_self_attn_layer_type: str = "rel_selfattn",
        conformer_activation_type: str = "swish",
        use_macaron_style_in_conformer: bool = True,
        use_cnn_in_conformer: bool = True,
        conformer_enc_kernel_size: int = 7,
        conformer_dec_kernel_size: int = 31,
        # duration predictor
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        # energy predictor
        energy_predictor_layers: int = 2,
        energy_predictor_chans: int = 384,
        energy_predictor_kernel_size: int = 3,
        energy_predictor_dropout: float = 0.5,
        energy_embed_kernel_size: int = 9,
        energy_embed_dropout: float = 0.5,
        stop_gradient_from_energy_predictor: bool = False,
        # pitch predictor
        pitch_predictor_layers: int = 2,
        pitch_predictor_chans: int = 384,
        pitch_predictor_kernel_size: int = 3,
        pitch_predictor_dropout: float = 0.5,
        pitch_embed_kernel_size: int = 9,
        pitch_embed_dropout: float = 0.5,
        stop_gradient_from_pitch_predictor: bool = False,
        # pretrained spk emb
        spk_embed_dim: Optional[int] = None,
        spk_embed_integration_type: str = "add",
        # GST
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        duration_predictor_dropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech2 module.

        Args:
            idim (int): Dimension of the inputs (vocabulary size incl. eos).
            odim (int): Dimension of the outputs (feature dimension).
            adim (int): Attention dimension shared by encoder and decoder.
            aheads (int): Number of attention heads.
            elayers (int): Number of encoder layers.
            eunits (int): Number of encoder hidden units.
            dlayers (int): Number of decoder layers.
            dunits (int): Number of decoder hidden units.
            postnet_layers (int): Number of postnet layers (0 disables postnet).
            postnet_chans (int): Number of postnet channels.
            postnet_filts (int): Kernel size of postnet.
            positionwise_layer_type (str): Position-wise layer type.
            positionwise_conv_kernel_size (int): Kernel size of
                position-wise conv1d layer.
            use_scaled_pos_enc (bool): Whether to use trainable scaled
                positional encoding.
            use_batch_norm (bool): Whether to use batch norm in the postnet.
            encoder_normalize_before (bool): Apply layernorm before each
                encoder block.
            decoder_normalize_before (bool): Apply layernorm before each
                decoder block.
            encoder_concat_after (bool): Concatenate attention input and
                output in encoder.
            decoder_concat_after (bool): Concatenate attention input and
                output in decoder.
            reduction_factor (int): Reduction factor.
            encoder_type (str): "transformer" or "conformer".
            decoder_type (str): "transformer" or "conformer".
            conformer_* / use_*_in_conformer: Conformer-only options
                (see ConformerEncoder).
            duration_predictor_* : Duration predictor options.
            energy_predictor_* / energy_embed_* : Energy predictor and
                energy embedding options.
            stop_gradient_from_energy_predictor (bool): Detach encoder
                output before the energy predictor.
            pitch_predictor_* / pitch_embed_* : Pitch predictor and pitch
                embedding options.
            stop_gradient_from_pitch_predictor (bool): Detach encoder
                output before the pitch predictor.
            spk_embed_dim (Optional[int]): Speaker embedding dimension;
                None disables speaker-embedding integration.
            spk_embed_integration_type (str): "add" or "concat".
            use_gst (bool): Whether to use global style token (GST) encoder.
            gst_* : GST style-encoder options (see StyleEncoder).
            transformer_*_dropout_rate (float): Dropout rates for the
                transformer encoder/decoder sub-modules.
            duration_predictor_dropout_rate (float): Duration predictor
                dropout rate.
            postnet_dropout_rate (float): Postnet dropout rate.
            init_type (str): Parameter initialization type.
            init_enc_alpha (float): Initial alpha of scaled positional
                encoding in the encoder.
            init_dec_alpha (float): Initial alpha of scaled positional
                encoding in the decoder.
            use_masking (bool): Apply masking for padded parts in loss.
            use_weighted_masking (bool): Apply weighted masking in loss.
        """
        assert check_argument_types()
        super().__init__()

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        # the last vocabulary index is reserved as the <eos> symbol
        self.eos = idim - 1
        self.reduction_factor = reduction_factor
        self.encoder_type = encoder_type
        self.decoder_type = decoder_type
        self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
        self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst
        self.spk_embed_dim = spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (ScaledPositionalEncoding
                         if self.use_scaled_pos_enc else PositionalEncoding)

        # define encoder
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=adim,
                                                 padding_idx=self.padding_idx)
        if encoder_type == "transformer":
            self.encoder = TransformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif encoder_type == "conformer":
            self.encoder = ConformerEncoder(
                idim=idim,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=eunits,
                num_blocks=elayers,
                input_layer=encoder_input_layer,
                dropout_rate=transformer_enc_dropout_rate,
                positional_dropout_rate=transformer_enc_positional_dropout_rate,
                attention_dropout_rate=transformer_enc_attn_dropout_rate,
                normalize_before=encoder_normalize_before,
                concat_after=encoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_enc_kernel_size,
            )
        else:
            raise ValueError(f"{encoder_type} is not supported.")

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
            )

        # define additional projection for speaker embedding
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim,
                                                  adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
        )

        # define pitch predictor
        self.pitch_predictor = VariancePredictor(
            idim=adim,
            n_layers=pitch_predictor_layers,
            n_chans=pitch_predictor_chans,
            kernel_size=pitch_predictor_kernel_size,
            dropout_rate=pitch_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg
        self.pitch_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=pitch_embed_kernel_size,
                padding=(pitch_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(pitch_embed_dropout),
        )

        # define energy predictor
        self.energy_predictor = VariancePredictor(
            idim=adim,
            n_layers=energy_predictor_layers,
            n_chans=energy_predictor_chans,
            kernel_size=energy_predictor_kernel_size,
            dropout_rate=energy_predictor_dropout,
        )
        # NOTE(kan-bayashi): We use continuous energy + FastPitch style avg
        self.energy_embed = torch.nn.Sequential(
            torch.nn.Conv1d(
                in_channels=1,
                out_channels=adim,
                kernel_size=energy_embed_kernel_size,
                padding=(energy_embed_kernel_size - 1) // 2,
            ),
            torch.nn.Dropout(energy_embed_dropout),
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        if decoder_type == "transformer":
            self.decoder = TransformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                pos_enc_class=pos_enc_class,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            )
        elif decoder_type == "conformer":
            self.decoder = ConformerEncoder(
                idim=0,
                attention_dim=adim,
                attention_heads=aheads,
                linear_units=dunits,
                num_blocks=dlayers,
                input_layer=None,
                dropout_rate=transformer_dec_dropout_rate,
                positional_dropout_rate=transformer_dec_positional_dropout_rate,
                attention_dropout_rate=transformer_dec_attn_dropout_rate,
                normalize_before=decoder_normalize_before,
                concat_after=decoder_concat_after,
                positionwise_layer_type=positionwise_layer_type,
                positionwise_conv_kernel_size=positionwise_conv_kernel_size,
                macaron_style=use_macaron_style_in_conformer,
                pos_enc_layer_type=conformer_pos_enc_layer_type,
                selfattention_layer_type=conformer_self_attn_layer_type,
                activation_type=conformer_activation_type,
                use_cnn_module=use_cnn_in_conformer,
                cnn_module_kernel=conformer_dec_kernel_size,
            )
        else:
            raise ValueError(f"{decoder_type} is not supported.")

        # define final projection (predicts reduction_factor frames at once)
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet (disabled when postnet_layers == 0)
        self.postnet = (None if postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        ))

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeech2Loss(
            use_masking=use_masking,
            use_weighted_masking=use_weighted_masking)

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        pitch: torch.Tensor,
        pitch_lengths: torch.Tensor,
        energy: torch.Tensor,
        energy_lengths: torch.Tensor,
        spembs: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded token ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            durations (LongTensor): Batch of padded durations (B, Tmax + 1).
            durations_lengths (LongTensor): Batch of duration lengths (B, Tmax + 1).
            pitch (Tensor): Batch of padded token-averaged pitch (B, Tmax + 1, 1).
            pitch_lengths (LongTensor): Batch of pitch lengths (B, Tmax + 1).
            energy (Tensor): Batch of padded token-averaged energy (B, Tmax + 1, 1).
            energy_lengths (LongTensor): Batch of energy lengths (B, Tmax + 1).
            spembs (Tensor, optional): Batch of speaker embeddings
                (B, spk_embed_dim).

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value.
        """
        text = text[:, :text_lengths.max()]  # for data-parallel
        speech = speech[:, :speech_lengths.max()]  # for data-parallel
        durations = durations[:, :durations_lengths.max()]  # for data-parallel
        pitch = pitch[:, :pitch_lengths.max()]  # for data-parallel
        energy = energy[:, :energy_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        # Add eos at the last of sequence
        xs = F.pad(text, [0, 1], "constant", self.padding_idx)
        for i, l in enumerate(text_lengths):
            xs[i, l] = self.eos
        ilens = text_lengths + 1

        ys, ds, ps, es = speech, durations, pitch, energy
        olens = speech_lengths

        # forward propagation
        before_outs, after_outs, d_outs, p_outs, e_outs = self._forward(
            xs,
            ilens,
            ys,
            olens,
            ds,
            ps,
            es,
            spembs=spembs,
            is_inference=False)

        # modify mod part of groundtruth
        # (trim targets so their length is a multiple of reduction_factor)
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens])
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss
        if self.postnet is None:
            after_outs = None
        l1_loss, duration_loss, pitch_loss, energy_loss = self.criterion(
            after_outs=after_outs,
            before_outs=before_outs,
            d_outs=d_outs,
            p_outs=p_outs,
            e_outs=e_outs,
            ys=ys,
            ds=ds,
            ps=ps,
            es=es,
            ilens=ilens,
            olens=olens,
        )
        loss = l1_loss + duration_loss + pitch_loss + energy_loss

        stats = dict(
            l1_loss=l1_loss.item(),
            duration_loss=duration_loss.item(),
            pitch_loss=pitch_loss.item(),
            energy_loss=energy_loss.item(),
            loss=loss.item(),
        )

        # report extra information (trainable positional-encoding scales)
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(), )
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            stats.update(
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(), )

        loss, stats, weight = force_gatherable((loss, stats, batch_size),
                                               loss.device)
        return loss, stats, weight

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: Optional[torch.Tensor] = None,
        olens: Optional[torch.Tensor] = None,
        ds: Optional[torch.Tensor] = None,
        ps: Optional[torch.Tensor] = None,
        es: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        is_inference: bool = False,
        alpha: float = 1.0,
    ) -> Sequence[torch.Tensor]:
        """Run the shared training/inference computation.

        Encodes the tokens, mixes in optional GST style and speaker
        embeddings, predicts pitch/energy/duration (teacher-forced in
        training), expands with the length regulator, and decodes.

        Returns:
            Tuple of (before_outs, after_outs, d_outs, p_outs, e_outs).
        """
        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks)  # (B, Tmax, adim)

        # integrate with GST (style is extracted from the target features)
        if self.use_gst:
            style_embs = self.gst(ys)
            hs = hs + style_embs.unsqueeze(1)

        # integrate speaker embedding
        if self.spk_embed_dim is not None:
            hs = self._integrate_with_spk_embed(hs, spembs)

        # forward duration predictor and variance predictors
        d_masks = make_pad_mask(ilens).to(xs.device)

        # optionally detach so predictor gradients don't reach the encoder
        if self.stop_gradient_from_pitch_predictor:
            p_outs = self.pitch_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            p_outs = self.pitch_predictor(hs, d_masks.unsqueeze(-1))
        if self.stop_gradient_from_energy_predictor:
            e_outs = self.energy_predictor(hs.detach(), d_masks.unsqueeze(-1))
        else:
            e_outs = self.energy_predictor(hs, d_masks.unsqueeze(-1))

        if is_inference:
            d_outs = self.duration_predictor.inference(hs,
                                                       d_masks)  # (B, Tmax)
            # use prediction in inference
            p_embs = self.pitch_embed(p_outs.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(e_outs.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, d_outs, alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(hs, d_masks)
            # use groundtruth in training
            p_embs = self.pitch_embed(ps.transpose(1, 2)).transpose(1, 2)
            e_embs = self.energy_embed(es.transpose(1, 2)).transpose(1, 2)
            hs = hs + e_embs + p_embs
            hs = self.length_regulator(hs, ds)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                # decoder operates on the reduced time axis
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens])
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(zs.size(0), -1,
                                             self.odim)  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            # postnet predicts a residual refinement of the features
            after_outs = before_outs + self.postnet(before_outs.transpose(
                1, 2)).transpose(1, 2)

        return before_outs, after_outs, d_outs, p_outs, e_outs

    def inference(
        self,
        text: torch.Tensor,
        speech: Optional[torch.Tensor] = None,
        spembs: Optional[torch.Tensor] = None,
        durations: Optional[torch.Tensor] = None,
        pitch: Optional[torch.Tensor] = None,
        energy: Optional[torch.Tensor] = None,
        alpha: float = 1.0,
        use_teacher_forcing: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T,).
            speech (Tensor, optional): Feature sequence to extract style (N, idim).
            spembs (Tensor, optional): Speaker embedding vector (spk_embed_dim,).
            durations (LongTensor, optional): Groundtruth of duration (T + 1,).
            pitch (Tensor, optional): Groundtruth of token-averaged pitch
                (T + 1, 1).
            energy (Tensor, optional): Groundtruth of token-averaged energy
                (T + 1, 1).
            alpha (float, optional): Alpha to control the speed.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing.
                If true, groundtruth of duration, pitch and energy will be used.

        Returns:
            Tensor: Output sequence of features (L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.
        """
        x, y = text, speech
        spemb, d, p, e = spembs, durations, pitch, energy

        # add eos at the last of sequence
        x = F.pad(x, [0, 1], "constant", self.eos)

        # setup batch axis
        ilens = torch.tensor([x.shape[0]], dtype=torch.long, device=x.device)
        xs, ys = x.unsqueeze(0), None
        if y is not None:
            ys = y.unsqueeze(0)
        if spemb is not None:
            spembs = spemb.unsqueeze(0)

        if use_teacher_forcing:
            # use groundtruth of duration, pitch, and energy
            ds, ps, es = d.unsqueeze(0), p.unsqueeze(0), e.unsqueeze(0)
            _, outs, *_ = self._forward(
                xs,
                ilens,
                ys,
                ds=ds,
                ps=ps,
                es=es,
                spembs=spembs,
            )  # (1, L, odim)
        else:
            _, outs, *_ = self._forward(
                xs,
                ilens,
                ys,
                spembs=spembs,
                is_inference=True,
                alpha=alpha,
            )  # (1, L, odim)

        return outs[0], None, None

    def _integrate_with_spk_embed(self, hs: torch.Tensor,
                                  spembs: torch.Tensor) -> torch.Tensor:
        """Integrate speaker embedding with hidden states.

        Args:
            hs (Tensor): Batch of hidden state sequences (B, Tmax, adim).
            spembs (Tensor): Batch of speaker embeddings (B, spk_embed_dim).

        Returns:
            Tensor: Batch of integrated hidden state sequences (B, Tmax, adim).
        """
        if self.spk_embed_integration_type == "add":
            # apply projection and then add to hidden states
            spembs = self.projection(F.normalize(spembs))
            hs = hs + spembs.unsqueeze(1)
        elif self.spk_embed_integration_type == "concat":
            # concat hidden states with spk embeds and then apply projection
            spembs = F.normalize(spembs).unsqueeze(1).expand(
                -1, hs.size(1), -1)
            hs = self.projection(torch.cat([hs, spembs], dim=-1))
        else:
            raise NotImplementedError("support only add or concat.")

        return hs

    def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor:
        """Make masks for self-attention.

        Args:
            ilens (LongTensor): Batch of lengths (B,).

        Returns:
            Tensor: Mask tensor for self-attention.
                dtype=torch.uint8 in PyTorch 1.2-
                dtype=torch.bool in PyTorch 1.2+ (including 1.2)

        Examples:
            >>> ilens = [5, 3]
            >>> self._source_mask(ilens)
            tensor([[[1, 1, 1, 1, 1],
                     [1, 1, 1, 0, 0]]], dtype=torch.uint8)
        """
        x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device)
        return x_masks.unsqueeze(-2)

    def _reset_parameters(self, init_type: str, init_enc_alpha: float,
                          init_dec_alpha: float):
        """Initialize module parameters and positional-encoding alphas."""
        # initialize parameters ("pytorch" keeps the framework defaults)
        if init_type != "pytorch":
            initialize(self, init_type)

        # initialize alpha in scaled positional encoding
        if self.encoder_type == "transformer" and self.use_scaled_pos_enc:
            self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha)
        if self.decoder_type == "transformer" and self.use_scaled_pos_enc:
            self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha)
    def __init__(self, idim, odim, args=None):
        """Initialize feed-forward Transformer module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - elayers (int): Number of encoder layers.
                - eunits (int): Number of encoder hidden units.
                - adim (int): Number of attention transformation dimensions.
                - aheads (int): Number of heads for multi head attention.
                - dlayers (int): Number of decoder layers.
                - dunits (int): Number of decoder hidden units.
                - use_scaled_pos_enc (bool):
                    Whether to use trainable scaled positional encoding.
                - encoder_normalize_before (bool):
                    Whether to perform layer normalization before encoder block.
                - decoder_normalize_before (bool):
                    Whether to perform layer normalization before decoder block.
                - encoder_concat_after (bool):
                    Whether to concatenate attention layer's input and output in encoder.
                - decoder_concat_after (bool):
                    Whether to concatenate attention layer's input and output in decoder.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int):
                    Kernel size of duration predictor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spk_embed_integration_type: How to integrate speaker embedding.
                - teacher_model (str): Teacher auto-regressive transformer model path.
                - reduction_factor (int): Reduction factor.
                - transformer_init (float): How to initialize transformer parameters.
                - transformer_lr (float): Initial value of learning rate.
                - transformer_warmup_steps (int): Optimizer warmup steps.
                - transformer_enc_dropout_rate (float):
                    Dropout rate in encoder except attention & positional encoding.
                - transformer_enc_positional_dropout_rate (float):
                    Dropout rate after encoder positional encoding.
                - transformer_enc_attn_dropout_rate (float):
                    Dropout rate in encoder self-attention module.
                - transformer_dec_dropout_rate (float):
                    Dropout rate in decoder except attention & positional encoding.
                - transformer_dec_positional_dropout_rate (float):
                    Dropout rate after decoder positional encoding.
                - transformer_dec_attn_dropout_rate (float):
                    Dropout rate in decoder self-attention module.
                - transformer_enc_dec_attn_dropout_rate (float):
                    Dropout rate in encoder-decoder attention module.
                - use_masking (bool):
                    Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool):
                    Whether to apply weighted masking in loss calculation.
                - transfer_encoder_from_teacher:
                    Whether to transfer encoder using teacher encoder parameters.
                - transferred_encoder_module:
                    Encoder module to be initialized using teacher parameters.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments with the defaults declared in add_arguments
        args = fill_missing_args(args, self.add_arguments)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = args.reduction_factor
        self.use_scaled_pos_enc = args.use_scaled_pos_enc
        self.spk_embed_dim = args.spk_embed_dim
        if self.spk_embed_dim is not None:
            self.spk_embed_integration_type = args.spk_embed_integration_type

        # use idx 0 as padding idx
        padding_idx = 0

        # get positional encoding class
        pos_enc_class = ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding

        # define encoder
        # NOTE(review): submodules below are assigned in a fixed order; do not
        # reorder these assignments — parameter registration order affects
        # state_dict layout and checkpoint compatibility.
        encoder_input_layer = torch.nn.Embedding(num_embeddings=idim,
                                                 embedding_dim=args.adim,
                                                 padding_idx=padding_idx)
        self.encoder = Encoder(
            idim=idim,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.eunits,
            num_blocks=args.elayers,
            input_layer=encoder_input_layer,
            dropout_rate=args.transformer_enc_dropout_rate,
            positional_dropout_rate=args.
            transformer_enc_positional_dropout_rate,
            attention_dropout_rate=args.transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.encoder_normalize_before,
            concat_after=args.encoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define additional projection for speaker embedding
        # ("add" projects spk-emb to adim; otherwise concat + project back)
        if self.spk_embed_dim is not None:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim,
                                                  args.adim)
            else:
                self.projection = torch.nn.Linear(
                    args.adim + self.spk_embed_dim, args.adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=args.adim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=args.adim,
            attention_heads=args.aheads,
            linear_units=args.dunits,
            num_blocks=args.dlayers,
            input_layer=None,
            dropout_rate=args.transformer_dec_dropout_rate,
            positional_dropout_rate=args.
            transformer_dec_positional_dropout_rate,
            attention_dropout_rate=args.transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=args.decoder_normalize_before,
            concat_after=args.decoder_concat_after,
            positionwise_layer_type=args.positionwise_layer_type,
            positionwise_conv_kernel_size=args.positionwise_conv_kernel_size)

        # define final projection
        # (emits reduction_factor frames per decoder step)
        self.feat_out = torch.nn.Linear(args.adim,
                                        odim * args.reduction_factor)

        # define postnet (disabled entirely when postnet_layers == 0)
        self.postnet = None if args.postnet_layers == 0 else Postnet(
            idim=idim,
            odim=odim,
            n_layers=args.postnet_layers,
            n_chans=args.postnet_chans,
            n_filts=args.postnet_filts,
            use_batch_norm=args.use_batch_norm,
            dropout_rate=args.postnet_dropout_rate)

        # initialize parameters
        self._reset_parameters(init_type=args.transformer_init,
                               init_enc_alpha=args.initial_encoder_alpha,
                               init_dec_alpha=args.initial_decoder_alpha)

        # define teacher model (auto-regressive model used for distillation)
        if args.teacher_model is not None:
            self.teacher = self._load_teacher_model(args.teacher_model)
        else:
            self.teacher = None

        # define duration calculator (derives target durations from the
        # teacher's attention; only available when a teacher is loaded)
        if self.teacher is not None:
            self.duration_calculator = DurationCalculator(self.teacher)
        else:
            self.duration_calculator = None

        # transfer teacher parameters
        if self.teacher is not None and args.transfer_encoder_from_teacher:
            self._transfer_from_teacher(args.transferred_encoder_module)

        # define criterions
        self.criterion = FeedForwardTransformerLoss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking)
class Tacotron2_sa(TTSInterface, torch.nn.Module):
    """Tacotron2-based TTS model with explicit duration/pitch/energy conditioning.

    Combines a Tacotron2-style encoder/decoder with a duration predictor and
    (optionally) pitch/energy variance predictors, and exposes intermediate
    "distill items" from encoder, decoder, and prosody modules — presumably
    for knowledge distillation to a student model (see ``is_student=False``
    flags below); confirm against the training script.
    """

    @staticmethod
    def add_arguments(parser):
        """Add model-specific arguments to the parser."""
        group = parser.add_argument_group("tacotron 2 model setting")
        # encoder
        group.add_argument(
            "--embed-dim",
            default=512,
            type=int,
            help="Number of dimension of embedding",
        )
        group.add_argument(
            "--elayers", default=1, type=int, help="Number of encoder layers"
        )
        group.add_argument(
            "--eunits",
            "-u",
            default=512,
            type=int,
            help="Number of encoder hidden units",
        )
        group.add_argument(
            "--econv-layers",
            default=3,
            type=int,
            help="Number of encoder convolution layers",
        )
        group.add_argument(
            "--econv-chans",
            default=512,
            type=int,
            help="Number of encoder convolution channels",
        )
        group.add_argument(
            "--econv-filts",
            default=5,
            type=int,
            help="Filter size of encoder convolution",
        )
        # decoder
        group.add_argument(
            "--dlayers", default=2, type=int, help="Number of decoder layers"
        )
        group.add_argument(
            "--dunits", default=1024, type=int, help="Number of decoder hidden units"
        )
        group.add_argument(
            "--prenet-layers", default=2, type=int, help="Number of prenet layers"
        )
        group.add_argument(
            "--prenet-units",
            default=256,
            type=int,
            help="Number of prenet hidden units",
        )
        group.add_argument(
            "--postnet-layers", default=5, type=int, help="Number of postnet layers"
        )
        group.add_argument(
            "--postnet-chans", default=512, type=int, help="Number of postnet channels"
        )
        group.add_argument(
            "--postnet-filts", default=5, type=int, help="Filter size of postnet"
        )
        group.add_argument(
            "--output-activation",
            default=None,
            type=str,
            nargs="?",
            help="Output activation function",
        )
        # model (parameter) related
        group.add_argument(
            "--use-batch-norm",
            default=True,
            type=strtobool,
            help="Whether to use batch normalization",
        )
        group.add_argument(
            "--use-concate",
            default=True,
            type=strtobool,
            help="Whether to concatenate encoder embedding with decoder outputs",
        )
        group.add_argument(
            "--use-residual",
            default=True,
            type=strtobool,
            help="Whether to use residual connection in conv layer",
        )
        group.add_argument(
            "--dropout-rate", default=0.5, type=float, help="Dropout rate"
        )
        group.add_argument(
            "--zoneout-rate", default=0.1, type=float, help="Zoneout rate"
        )
        group.add_argument(
            "--reduction-factor", default=1, type=int, help="Reduction factor"
        )
        group.add_argument(
            "--spk-embed-dim",
            default=None,
            type=int,
            help="Number of speaker embedding dimensions",
        )
        group.add_argument(
            "--spc-dim", default=None, type=int, help="Number of spectrogram dimensions"
        )
        group.add_argument(
            "--pretrained-model", default=None, type=str, help="Pretrained model path"
        )
        # loss related
        group.add_argument(
            "--use-masking",
            default=False,
            type=strtobool,
            help="Whether to use masking in calculation of loss",
        )
        group.add_argument(
            "--use-weighted-masking",
            default=False,
            type=strtobool,
            help="Whether to use weighted masking in calculation of loss",
        )
        # duration predictor settings
        group.add_argument(
            "--duration-predictor-layers",
            default=2,
            type=int,
            help="Number of layers in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-chans",
            default=384,
            type=int,
            help="Number of channels in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-kernel-size",
            default=3,
            type=int,
            help="Kernel size in duration predictor",
        )
        group.add_argument(
            "--duration-predictor-dropout-rate",
            default=0.1,
            type=float,
            help="Dropout rate for duration predictor",
        )
        return parser

    def __init__(self, idim, odim, args=None, com_args=None):
        """Initialize Tacotron2 module.

        Args:
            idim (int): Dimension of the inputs.
            odim (int): Dimension of the outputs.
            args (Namespace, optional):
                - spk_embed_dim (int): Dimension of the speaker embedding.
                - embed_dim (int): Dimension of character embedding.
                - elayers (int): The number of encoder blstm layers.
                - eunits (int): The number of encoder blstm units.
                - econv_layers (int): The number of encoder conv layers.
                - econv_filts (int): The number of encoder conv filter size.
                - econv_chans (int): The number of encoder conv filter channels.
                - dlayers (int): The number of decoder lstm layers.
                - dunits (int): The number of decoder lstm units.
                - prenet_layers (int): The number of prenet layers.
                - prenet_units (int): The number of prenet units.
                - postnet_layers (int): The number of postnet layers.
                - postnet_filts (int): The number of postnet filter size.
                - postnet_chans (int): The number of postnet filter channels.
                - output_activation (int): The name of activation function for outputs.
                - use_batch_norm (bool): Whether to use batch normalization.
                - use_concate (int):
                    Whether to concatenate encoder embedding with decoder lstm outputs.
                - dropout_rate (float): Dropout rate.
                - zoneout_rate (float): Zoneout rate.
                - reduction_factor (int): Reduction factor.
                - spk_embed_dim (int): Number of speaker embedding dimensions.
                - spc_dim (int): Number of spectrogram embedding dimensions
                    (only for use_cbhg=True)
                - use_masking (bool):
                    Whether to apply masking for padded part in loss calculation.
                - use_weighted_masking (bool):
                    Whether to apply weighted masking in loss calculation.
                - duration_predictor_layers (int): Number of duration predictor layers.
                - duration_predictor_chans (int): Number of duration predictor channels.
                - duration_predictor_kernel_size (int):
                    Kernel size of duration predictor.
            com_args (Namespace, optional): Common arguments providing fallback
                values for ``use_fe_condition`` and ``append_position`` when
                they are absent from ``args``.

        """
        # initialize base classes
        TTSInterface.__init__(self)
        torch.nn.Module.__init__(self)

        # fill missing arguments
        args = fill_missing_args(args, self.add_arguments)
        # backfill flags not declared in add_arguments from com_args
        args = vars(args)
        if 'use_fe_condition' not in args.keys():
            args['use_fe_condition'] = com_args.use_fe_condition
        if 'append_position' not in args.keys():
            args['append_position'] = com_args.append_position
        args = argparse.Namespace(**args)

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.embed_dim = args.embed_dim
        self.spk_embed_dim = args.spk_embed_dim
        self.reduction_factor = args.reduction_factor
        self.use_fe_condition = args.use_fe_condition
        self.append_position = args.append_position

        # define activation function for the final output
        # (looked up by name on torch.nn.functional)
        if args.output_activation is None:
            self.output_activation_fn = None
        elif hasattr(F, args.output_activation):
            self.output_activation_fn = getattr(F, args.output_activation)
        else:
            raise ValueError(
                "there is no such an activation function. (%s)"
                % args.output_activation
            )

        # set padding idx
        padding_idx = 0

        # define network modules
        # NOTE(review): keep the module-assignment order below unchanged —
        # registration order affects state_dict layout and checkpoint loading.
        self.enc = Encoder(
            idim=idim,
            embed_dim=args.embed_dim,
            elayers=args.elayers,
            eunits=args.eunits,
            econv_layers=args.econv_layers,
            econv_chans=args.econv_chans,
            econv_filts=args.econv_filts,
            use_batch_norm=args.use_batch_norm,
            use_residual=args.use_residual,
            dropout_rate=args.dropout_rate,
            padding_idx=padding_idx,
            is_student=False,
        )
        # decoder input grows by spk_embed_dim when speaker embeddings are
        # concatenated to the encoder outputs (see forward)
        dec_idim = (
            args.eunits
            if args.spk_embed_dim is None
            else args.eunits + args.spk_embed_dim
        )
        self.dec = Decoder(
            idim=dec_idim,
            odim=odim,
            dlayers=args.dlayers,
            dunits=args.dunits,
            prenet_layers=args.prenet_layers,
            prenet_units=args.prenet_units,
            postnet_layers=args.postnet_layers,
            postnet_chans=args.postnet_chans,
            postnet_filts=args.postnet_filts,
            output_activation_fn=self.output_activation_fn,
            use_batch_norm=args.use_batch_norm,
            use_concate=args.use_concate,
            dropout_rate=args.dropout_rate,
            zoneout_rate=args.zoneout_rate,
            reduction_factor=args.reduction_factor,
            use_fe_condition=args.use_fe_condition,
            append_position=args.append_position,
            is_student=False,
        )
        self.duration_predictor = DurationPredictor(
            idim=dec_idim,
            n_layers=args.duration_predictor_layers,
            n_chans=args.duration_predictor_chans,
            kernel_size=args.duration_predictor_kernel_size,
            dropout_rate=args.duration_predictor_dropout_rate,
        )
        # reduction = 'none' if args.use_weighted_masking else 'mean'
        # self.duration_criterion = DurationPredictorLoss(reduction=reduction)

        #-------------- picth/energy predictor definition ---------------#
        # NOTE(review): the pitch/energy hyperparameters below are hard-coded
        # locals rather than CLI arguments — presumably fixed FastSpeech2-style
        # defaults; confirm they should not be configurable.
        if self.use_fe_condition:
            output_dim=1
            # pitch prediction
            pitch_predictor_layers=2
            pitch_predictor_chans=384
            pitch_predictor_kernel_size=3
            pitch_predictor_dropout_rate=0.5
            pitch_embed_kernel_size=9
            pitch_embed_dropout_rate=0.5
            self.stop_gradient_from_pitch_predictor=False
            self.pitch_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=pitch_predictor_layers,
                n_chans=pitch_predictor_chans,
                kernel_size=pitch_predictor_kernel_size,
                dropout_rate=pitch_predictor_dropout_rate,
                output_dim=output_dim,
            )
            # "same" padding so the embedded pitch keeps the input length
            self.pitch_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=pitch_embed_kernel_size,
                    padding=(pitch_embed_kernel_size-1)//2,
                ),
                torch.nn.Dropout(pitch_embed_dropout_rate),
            )
            # energy prediction
            energy_predictor_layers=2
            energy_predictor_chans=384
            energy_predictor_kernel_size=3
            energy_predictor_dropout_rate=0.5
            energy_embed_kernel_size=9
            energy_embed_dropout_rate=0.5
            self.stop_gradient_from_energy_predictor=False
            self.energy_predictor = VariancePredictor(
                idim=dec_idim,
                n_layers=energy_predictor_layers,
                n_chans=energy_predictor_chans,
                kernel_size=energy_predictor_kernel_size,
                dropout_rate=energy_predictor_dropout_rate,
                output_dim=output_dim,
            )
            self.energy_embed = torch.nn.Sequential(
                torch.nn.Conv1d(
                    in_channels=1,
                    out_channels=dec_idim,
                    kernel_size=energy_embed_kernel_size,
                    padding=(energy_embed_kernel_size-1)//2,
                ),
                torch.nn.Dropout(energy_embed_dropout_rate),
            )

        # # define criterions
        # self.prosody_criterion = prosody_criterions(
        #     use_masking=args.use_masking,
        #     use_weighted_masking=args.use_weighted_masking)
        self.taco2_loss = Tacotron2Loss(
            use_masking=args.use_masking,
            use_weighted_masking=args.use_weighted_masking,
        )

        # load pretrained model
        if args.pretrained_model is not None:
            self.load_pretrained_model(args.pretrained_model)

        # NOTE(review): the parameter-count report below accesses
        # self.pitch_predictor / self.energy_predictor / *_embed
        # unconditionally, so construction fails with AttributeError when
        # use_fe_condition is False — presumably this model is only ever run
        # with use_fe_condition=True; confirm.
        print('\n############## number of network parameters ##############\n')
        parameters = filter(lambda p: p.requires_grad, self.enc.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for Encoder: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.dec.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for Decoder: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.duration_predictor.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for duration_predictor: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.pitch_predictor.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for pitch_predictor: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.energy_predictor.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for energy_predictor: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.pitch_embed.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for pitch_embed: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.energy_embed.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for energy_embed: %.5fM' % parameters)
        parameters = filter(lambda p: p.requires_grad, self.parameters())
        parameters = sum([np.prod(p.size()) for p in parameters]) / 1_000_000
        print('Trainable Parameters for whole network: %.5fM' % parameters)
        print('\n##########################################################\n')

    def forward(
        self, xs, ilens, ys, olens, spembs=None, extras=None,
        new_ys=None, non_zero_lens_mask=None, ds_nonzeros=None,
        output_masks=None, position=None, f0=None, energy=None,
        *args, **kwargs
    ):
        """Calculate forward propagation.

        Args:
            xs (Tensor): Batch of padded character ids (B, Tmax).
            ilens (LongTensor): Batch of lengths of each input batch (B,).
            ys (Tensor): Batch of padded target features (B, Lmax, odim).
            olens (LongTensor): Batch of the lengths of each target (B,).
            spembs (Tensor, optional):
                Batch of speaker embedding vectors (B, spk_embed_dim).
            extras (Tensor, optional):
                Batch of groundtruth spectrograms (B, Lmax, spc_dim).
                NOTE(review): despite the docstring, extras is squeezed and
                used as groundtruth durations below — confirm with the caller.
            new_ys (Tensor): reorganized mel-spectrograms
            non_zero_lens_masks (Tensor)
            ds_nonzeros (Tensor)
            output_masks (Tensor)
            position (Tenor): position values for each phoneme
            f0 (Tensor): pitch
            energy (Tensor)

        Returns:
            Tensor: Loss value.

        """
        # remove unnecessary padded part (for multi-gpus)
        max_in = max(ilens)
        max_out = max(olens)
        if max_in != xs.shape[1]:
            xs = xs[:, :max_in]
        if max_out != ys.shape[1]:
            ys = ys[:, :max_out]

        # calculate FCL-taco2-enc outputs
        hs, hlens, enc_distill_items = self.enc(xs, ilens)

        # broadcast speaker embedding over time and concat to encoder output
        if self.spk_embed_dim is not None:
            spembs = F.normalize(spembs).unsqueeze(1).expand(-1, hs.size(1), -1)
            hs = torch.cat([hs, spembs], dim=-1)

        # duration predictor loss cal
        ds = extras.squeeze(-1)
        d_masks = make_pad_mask(ilens).to(xs.device)
        d_outs = self.duration_predictor(hs, d_masks)  # (B, Tmax)
        d_outs = d_outs.unsqueeze(-1)  # (B, Tmax, 1)
        # duration_masks = make_non_pad_mask(ilens).to(ys.device)
        # d_outs = d_outs.masked_select(duration_masks)
        # duration_loss = self.duration_criterion(d_outs, ds.masked_select(duration_masks))

        if self.use_fe_condition:
            expand_hs = hs
            fe_masks = d_masks
            # optionally cut the gradient flow from the variance predictors
            # back into the encoder
            if self.stop_gradient_from_pitch_predictor:
                p_outs = self.pitch_predictor(expand_hs.detach(),
                                              fe_masks.unsqueeze(-1))
            else:
                p_outs = self.pitch_predictor(expand_hs,
                                              fe_masks.unsqueeze(-1))  # B x Tmax x 1
            if self.stop_gradient_from_energy_predictor:
                e_outs = self.energy_predictor(expand_hs.detach(),
                                               fe_masks.unsqueeze(-1))
            else:
                e_outs = self.energy_predictor(expand_hs,
                                               fe_masks.unsqueeze(-1))  # B x Tmax x 1
            # pitch_loss = self.prosody_criterion(p_outs,f0,ilens)
            # energy_loss = self.prosody_criterion(e_outs,energy,ilens)
            # embed groundtruth pitch/energy for decoder conditioning
            p_embs = self.pitch_embed(f0.transpose(1,2)).transpose(1,2)
            e_embs = self.energy_embed(energy.transpose(1,2)).transpose(1,2)
        else:
            p_embs = None
            e_embs = None

        ylens = olens
        after_outs, before_outs, dec_distill_items = self.dec(hs, hlens, ds, ys, ylens,
                                    new_ys, non_zero_lens_mask, ds_nonzeros, output_masks,
                                    position, p_embs, e_embs)

        # NOTE(review): p_outs/e_outs are only bound inside the
        # use_fe_condition branch above, so this line raises NameError when
        # use_fe_condition is False (and detach_items would fail on the None
        # p_embs/e_embs) — presumably this path is never taken; confirm.
        prosody_distill_items = [d_outs, p_outs, e_outs, p_embs, e_embs]

        # detach all distill items so the student's distillation loss does not
        # backpropagate into this (teacher) network
        enc_distill_items = self.detach_items(enc_distill_items)
        dec_distill_items = self.detach_items(dec_distill_items)
        prosody_distill_items = self.detach_items(prosody_distill_items)

        return after_outs, before_outs, enc_distill_items, dec_distill_items, prosody_distill_items

    def detach_items(self, items):
        """Detach every tensor in ``items`` from the autograd graph.

        Args:
            items (list of Tensor): Tensors to detach.

        Returns:
            list of Tensor: The same tensors, detached.

        """
        items = [it.detach() for it in items]
        return items