def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    embed_dim: int = 512,
    elayers: int = 1,
    eunits: int = 512,
    econv_layers: int = 3,
    econv_chans: int = 512,
    econv_filts: int = 5,
    atype: str = "location",
    adim: int = 512,
    aconv_chans: int = 32,
    aconv_filts: int = 15,
    cumulate_att_w: bool = True,
    dlayers: int = 2,
    dunits: int = 1024,
    prenet_layers: int = 2,
    prenet_units: int = 256,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    output_activation: str = None,
    use_batch_norm: bool = True,
    use_concate: bool = True,
    use_residual: bool = False,
    reduction_factor: int = 1,
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "concat",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    dropout_rate: float = 0.5,
    zoneout_rate: float = 0.1,
    use_masking: bool = True,
    use_weighted_masking: bool = False,
    bce_pos_weight: float = 5.0,
    loss_type: str = "L1+L2",
    use_guided_attn_loss: bool = True,
    guided_attn_loss_sigma: float = 0.4,
    guided_attn_loss_lambda: float = 1.0,
):
    """Initialize Tacotron2 module.

    Builds, in this order, the encoder (``self.enc``), the optional style
    encoder (``self.gst``), the optional speaker-embedding projection
    (``self.projection``), the attention module, the decoder (``self.dec``)
    and the loss modules (``self.taco2_loss`` and, optionally,
    ``self.attn_loss``).

    Args:
        idim: Forwarded to the encoder; ``idim - 1`` is stored as ``self.eos``.
        odim: Forwarded to the decoder and used as the style encoder input
            dimension (mel-spectrogram).
        embed_dim, elayers, eunits, econv_layers, econv_chans, econv_filts,
        use_residual: Encoder settings.
        atype: Attention type: ``"location"``, ``"forward"`` or
            ``"forward_ta"``.
        adim, aconv_chans, aconv_filts: Attention settings.
        cumulate_att_w: Forwarded to the decoder. Forced to ``False`` (with a
            warning) for the two forward attention types.
        dlayers, dunits, prenet_layers, prenet_units, postnet_layers,
        postnet_chans, postnet_filts, use_concate, zoneout_rate,
        reduction_factor: Decoder settings.
        output_activation: Name of a function in ``F`` applied by the decoder,
            or ``None`` for no activation.
        use_batch_norm, dropout_rate: Shared by the encoder and the decoder.
        spk_embed_dim: If not ``None``, speaker embeddings of this size are
            integrated according to ``spk_embed_integration_type``.
        spk_embed_integration_type: ``"concat"`` widens the decoder input by
            ``spk_embed_dim``; ``"add"`` creates a linear projection from
            ``spk_embed_dim`` to ``eunits``.
        use_gst: Whether to create the style encoder; the ``gst_*`` arguments
            are forwarded to it.
        use_masking, use_weighted_masking, bce_pos_weight: Forwarded to
            ``Tacotron2Loss``.
        loss_type: Stored as ``self.loss_type``.
        use_guided_attn_loss: Whether to create ``GuidedAttentionLoss``.
        guided_attn_loss_sigma, guided_attn_loss_lambda: Forwarded to
            ``GuidedAttentionLoss`` as ``sigma`` and ``alpha``.

    Raises:
        ValueError: If ``output_activation`` is not an attribute of ``F`` or
            ``spk_embed_integration_type`` is unsupported.
        NotImplementedError: If ``atype`` is not a supported attention type.
    """
    assert check_argument_types()
    super().__init__()

    # keep the hyperparameters that are needed after construction
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.spk_embed_dim = spk_embed_dim
    self.cumulate_att_w = cumulate_att_w
    self.reduction_factor = reduction_factor
    self.use_gst = use_gst
    self.use_guided_attn_loss = use_guided_attn_loss
    self.loss_type = loss_type
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # look up the activation applied to the final output by name
    if output_activation is None:
        activation_fn = None
    elif hasattr(F, output_activation):
        activation_fn = getattr(F, output_activation)
    else:
        raise ValueError(
            f"there is no such an activation function. ({output_activation})"
        )
    self.output_activation_fn = activation_fn

    # index 0 is used for padding
    pad_id = 0
    self.padding_idx = pad_id

    # encoder
    self.enc = Encoder(
        idim=idim,
        embed_dim=embed_dim,
        elayers=elayers,
        eunits=eunits,
        econv_layers=econv_layers,
        econv_chans=econv_chans,
        econv_filts=econv_filts,
        use_batch_norm=use_batch_norm,
        use_residual=use_residual,
        dropout_rate=dropout_rate,
        padding_idx=pad_id,
    )

    # optional global style token encoder
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=eunits,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # decoder input size depends on how speaker embeddings are integrated
    dec_idim = eunits
    if spk_embed_dim is not None:
        if spk_embed_integration_type == "concat":
            dec_idim = eunits + spk_embed_dim
        elif spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, eunits)
        else:
            raise ValueError(f"{spk_embed_integration_type} is not supported.")

    # attention; both forward variants share the cumulation handling
    att_args = (dec_idim, dunits, adim, aconv_chans, aconv_filts)
    if atype == "location":
        att = AttLoc(*att_args)
    elif atype in ("forward", "forward_ta"):
        if atype == "forward":
            att = AttForward(*att_args)
        else:
            att = AttForwardTA(*att_args, odim)
        if self.cumulate_att_w:
            logging.warning(
                "cumulation of attention weights is disabled in forward attention."
            )
            self.cumulate_att_w = False
    else:
        raise NotImplementedError("Support only location or forward")

    # decoder
    self.dec = Decoder(
        idim=dec_idim,
        odim=odim,
        att=att,
        dlayers=dlayers,
        dunits=dunits,
        prenet_layers=prenet_layers,
        prenet_units=prenet_units,
        postnet_layers=postnet_layers,
        postnet_chans=postnet_chans,
        postnet_filts=postnet_filts,
        output_activation_fn=self.output_activation_fn,
        cumulate_att_w=self.cumulate_att_w,
        use_batch_norm=use_batch_norm,
        use_concate=use_concate,
        dropout_rate=dropout_rate,
        zoneout_rate=zoneout_rate,
        reduction_factor=reduction_factor,
    )

    # losses
    self.taco2_loss = Tacotron2Loss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
        bce_pos_weight=bce_pos_weight,
    )
    if self.use_guided_attn_loss:
        self.attn_loss = GuidedAttentionLoss(
            sigma=guided_attn_loss_sigma,
            alpha=guided_attn_loss_lambda,
        )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    embed_dim: int = 512,
    eprenet_conv_layers: int = 3,
    eprenet_conv_chans: int = 256,
    eprenet_conv_filts: int = 5,
    dprenet_layers: int = 2,
    dprenet_units: int = 256,
    elayers: int = 6,
    eunits: int = 1024,
    adim: int = 512,
    aheads: int = 4,
    dlayers: int = 6,
    dunits: int = 1024,
    postnet_layers: int = 5,
    postnet_chans: int = 256,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    reduction_factor: int = 1,
    # extra embedding related
    spks: Optional[int] = None,
    langs: Optional[int] = None,
    spk_embed_dim: Optional[int] = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    transformer_enc_dec_attn_dropout_rate: float = 0.1,
    eprenet_dropout_rate: float = 0.5,
    dprenet_dropout_rate: float = 0.5,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
    bce_pos_weight: float = 5.0,
    loss_type: str = "L1",
    use_guided_attn_loss: bool = True,
    num_heads_applied_guided_attn: int = 2,
    num_layers_applied_guided_attn: int = 2,
    modules_applied_guided_attn: Sequence[str] = ("encoder-decoder",),
    guided_attn_loss_sigma: float = 0.4,
    guided_attn_loss_lambda: float = 1.0,
):
    """Initialize Transformer module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        embed_dim (int): Dimension of character embedding.
        eprenet_conv_layers (int): Number of encoder prenet convolution layers.
            If 0, a plain embedding layer is used as the encoder input layer.
        eprenet_conv_chans (int): Number of encoder prenet convolution channels.
        eprenet_conv_filts (int): Filter size of encoder prenet convolution.
        dprenet_layers (int): Number of decoder prenet layers.
            If 0, the decoder input layer is set to "linear".
        dprenet_units (int): Number of decoder prenet hidden units.
        elayers (int): Number of encoder layers.
        eunits (int): Number of encoder hidden units.
        adim (int): Number of attention transformation dimensions.
        aheads (int): Number of heads for multi head attention.
        dlayers (int): Number of decoder layers.
        dunits (int): Number of decoder hidden units.
        postnet_layers (int): Number of postnet layers.
            If 0, no postnet is created (``self.postnet`` is None).
        postnet_chans (int): Number of postnet channels.
        postnet_filts (int): Filter size of postnet.
        use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding.
        use_batch_norm (bool): Whether to use batch normalization in encoder
            prenet.
        encoder_normalize_before (bool): Whether to apply layernorm layer before
            encoder block.
        decoder_normalize_before (bool): Whether to apply layernorm layer before
            decoder block.
        encoder_concat_after (bool): Whether to concatenate attention layer's
            input and output in encoder.
        decoder_concat_after (bool): Whether to concatenate attention layer's
            input and output in decoder.
        positionwise_layer_type (str): Position-wise operation type.
        positionwise_conv_kernel_size (int): Kernel size in position wise
            conv 1d.
        reduction_factor (int): Reduction factor.
        spks (Optional[int]): Number of speakers. If set to > 1, assume that the
            sids will be provided as the input and use sid embedding layer.
        langs (Optional[int]): Number of languages. If set to > 1, assume that
            the lids will be provided as the input and use lid embedding layer.
        spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to
            > 0, assume that spembs will be provided as the input.
        spk_embed_integration_type (str): How to integrate speaker embedding.
            "add" projects the embedding to ``adim``; any other value creates
            a projection from ``adim + spk_embed_dim`` to ``adim``.
        use_gst (bool): Whether to use global style token.
        gst_tokens (int): Number of GST embeddings.
        gst_heads (int): Number of heads in GST multihead attention.
        gst_conv_layers (int): Number of conv layers in GST.
        gst_conv_chans_list: (Sequence[int]): List of the number of channels of
            conv layers in GST.
        gst_conv_kernel_size (int): Kernel size of conv layers in GST.
        gst_conv_stride (int): Stride size of conv layers in GST.
        gst_gru_layers (int): Number of GRU layers in GST.
        gst_gru_units (int): Number of GRU units in GST.
        transformer_enc_dropout_rate (float): Dropout rate in encoder except
            attention and positional encoding.
        transformer_enc_positional_dropout_rate (float): Dropout rate after
            encoder positional encoding.
        transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
            self-attention module.
        transformer_dec_dropout_rate (float): Dropout rate in decoder except
            attention & positional encoding.
        transformer_dec_positional_dropout_rate (float): Dropout rate after
            decoder positional encoding.
        transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
            self-attention module.
        transformer_enc_dec_attn_dropout_rate (float): Dropout rate in source
            attention module.
        init_type (str): How to initialize transformer parameters.
        init_enc_alpha (float): Initial value of alpha in scaled pos encoding of
            the encoder.
        init_dec_alpha (float): Initial value of alpha in scaled pos encoding of
            the decoder.
        eprenet_dropout_rate (float): Dropout rate in encoder prenet.
        dprenet_dropout_rate (float): Dropout rate in decoder prenet.
        postnet_dropout_rate (float): Dropout rate in postnet.
        use_masking (bool): Whether to apply masking for padded part in loss
            calculation.
        use_weighted_masking (bool): Whether to apply weighted masking in loss
            calculation.
        bce_pos_weight (float): Positive sample weight in bce calculation
            (only for use_masking=true).
        loss_type (str): How to calculate loss.
        use_guided_attn_loss (bool): Whether to use guided attention loss.
        num_heads_applied_guided_attn (int): Number of heads in each layer to
            apply guided attention loss. -1 means ``aheads``.
        num_layers_applied_guided_attn (int): Number of layers to apply guided
            attention loss. -1 means ``elayers``.
        modules_applied_guided_attn (Sequence[str]): List of module names to
            apply guided attention loss. A single bare string is treated as a
            one-element list.
        guided_attn_loss_sigma (float): Sigma in guided attention loss.
        guided_attn_loss_lambda (float): Lambda in guided attention loss.

    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.use_gst = use_gst
    self.use_guided_attn_loss = use_guided_attn_loss
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.loss_type = loss_type
    if self.use_guided_attn_loss:
        # -1 is a sentinel meaning "all layers" / "all heads"
        if num_layers_applied_guided_attn == -1:
            self.num_layers_applied_guided_attn = elayers
        else:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
        if num_heads_applied_guided_attn == -1:
            self.num_heads_applied_guided_attn = aheads
        else:
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
        # A bare string is itself a Sequence[str] and passes the type check,
        # but a membership test against it is a substring test
        # (e.g. "encoder" in "encoder-decoder" is True). Wrap it so that
        # consumers always see a sequence of module names.
        if isinstance(modules_applied_guided_attn, str):
            modules_applied_guided_attn = (modules_applied_guided_attn,)
        self.modules_applied_guided_attn = modules_applied_guided_attn

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (
        ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
    )

    # define transformer encoder
    if eprenet_conv_layers != 0:
        # encoder prenet
        encoder_input_layer = torch.nn.Sequential(
            EncoderPrenet(
                idim=idim,
                embed_dim=embed_dim,
                elayers=0,
                econv_layers=eprenet_conv_layers,
                econv_chans=eprenet_conv_chans,
                econv_filts=eprenet_conv_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=eprenet_dropout_rate,
                padding_idx=self.padding_idx,
            ),
            torch.nn.Linear(eprenet_conv_chans, adim),
        )
    else:
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim,
            embedding_dim=adim,
            padding_idx=self.padding_idx,
        )
    self.encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define spk and lang embedding
    self.spks = None
    if spks is not None and spks > 1:
        self.spks = spks
        self.sid_emb = torch.nn.Embedding(spks, adim)
    self.langs = None
    if langs is not None and langs > 1:
        self.langs = langs
        self.lid_emb = torch.nn.Embedding(langs, adim)

    # define projection layer
    self.spk_embed_dim = None
    if spk_embed_dim is not None and spk_embed_dim > 0:
        self.spk_embed_dim = spk_embed_dim
        self.spk_embed_integration_type = spk_embed_integration_type
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    # define transformer decoder
    if dprenet_layers != 0:
        # decoder prenet
        decoder_input_layer = torch.nn.Sequential(
            DecoderPrenet(
                idim=odim,
                n_layers=dprenet_layers,
                n_units=dprenet_units,
                dropout_rate=dprenet_dropout_rate,
            ),
            torch.nn.Linear(dprenet_units, adim),
        )
    else:
        decoder_input_layer = "linear"
    self.decoder = Decoder(
        odim=odim,  # odim is needed when no prenet is used
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=dunits,
        num_blocks=dlayers,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        self_attention_dropout_rate=transformer_dec_attn_dropout_rate,
        src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate,
        input_layer=decoder_input_layer,
        use_output_layer=False,
        pos_enc_class=pos_enc_class,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
    )

    # define final projection
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
    self.prob_out = torch.nn.Linear(adim, reduction_factor)

    # define postnet
    self.postnet = (
        None
        if postnet_layers == 0
        else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )
    )

    # define loss function
    self.criterion = TransformerLoss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
        bce_pos_weight=bce_pos_weight,
    )
    if self.use_guided_attn_loss:
        self.attn_criterion = GuidedMultiHeadAttentionLoss(
            sigma=guided_attn_loss_sigma,
            alpha=guided_attn_loss_lambda,
        )

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    reduction_factor: int = 1,
    encoder_type: str = "transformer",
    decoder_type: str = "transformer",
    # only for conformer
    conformer_rel_pos_type: str = "legacy",
    conformer_pos_enc_layer_type: str = "rel_pos",
    conformer_self_attn_layer_type: str = "rel_selfattn",
    conformer_activation_type: str = "swish",
    use_macaron_style_in_conformer: bool = True,
    use_cnn_in_conformer: bool = True,
    conformer_enc_kernel_size: int = 7,
    conformer_dec_kernel_size: int = 31,
    zero_triu: bool = False,
    # pretrained spk emb
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    # GST
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    duration_predictor_dropout_rate: float = 0.1,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Initialize FastSpeech module.

    Builds, in this order, the encoder, the optional style encoder
    (``self.gst``), the optional speaker-embedding projection, the duration
    predictor, the length regulator, the decoder, the output projection
    (``self.feat_out``), the optional postnet, and finally the loss
    (``self.criterion``) after parameters have been reset.

    Args:
        idim: Dimension of the inputs; ``idim - 1`` is stored as ``self.eos``.
        odim: Dimension of the outputs.
        adim, aheads: Attention dimension and number of heads, shared by the
            encoder and the decoder.
        elayers, eunits: Number of encoder blocks and encoder linear units.
        dlayers, dunits: Number of decoder blocks and decoder linear units.
        postnet_layers: Number of postnet layers; 0 disables the postnet.
        postnet_chans, postnet_filts, postnet_dropout_rate: Postnet settings.
        positionwise_layer_type, positionwise_conv_kernel_size: Forwarded to
            both the encoder and the decoder.
        use_scaled_pos_enc: Selects ``ScaledPositionalEncoding`` over
            ``PositionalEncoding`` for the transformer encoder/decoder.
        use_batch_norm: Forwarded to the postnet.
        encoder_normalize_before, decoder_normalize_before,
        encoder_concat_after, decoder_concat_after: Forwarded to the encoder
            and the decoder respectively.
        duration_predictor_layers, duration_predictor_chans,
        duration_predictor_kernel_size, duration_predictor_dropout_rate:
            Duration predictor settings.
        reduction_factor: Multiplies the size of the output projection.
        encoder_type, decoder_type: ``"transformer"`` or ``"conformer"``.
        conformer_rel_pos_type: ``"legacy"`` or ``"latest"``. With
            ``"legacy"``, ``"rel_pos"`` / ``"rel_selfattn"`` are replaced by
            their ``legacy_`` counterparts (with a warning). Only checked
            when the encoder or the decoder is a conformer.
        conformer_pos_enc_layer_type, conformer_self_attn_layer_type,
        conformer_activation_type, use_macaron_style_in_conformer,
        use_cnn_in_conformer, conformer_enc_kernel_size,
        conformer_dec_kernel_size: Conformer-only settings.
        zero_triu: Accepted but not referenced by this constructor.
        spk_embed_dim: If not ``None``, a speaker-embedding projection is
            created.
        spk_embed_integration_type: ``"add"`` projects ``spk_embed_dim`` to
            ``adim``; any other value projects ``adim + spk_embed_dim`` to
            ``adim``.
        use_gst: Whether to create the style encoder; the ``gst_*`` arguments
            are forwarded to it.
        transformer_enc_dropout_rate, transformer_enc_positional_dropout_rate,
        transformer_enc_attn_dropout_rate: Encoder dropout rates.
        transformer_dec_dropout_rate, transformer_dec_positional_dropout_rate,
        transformer_dec_attn_dropout_rate: Decoder dropout rates.
        init_type, init_enc_alpha, init_dec_alpha: Forwarded to
            ``self._reset_parameters``.
        use_masking, use_weighted_masking: Forwarded to ``FastSpeechLoss``.

    Raises:
        ValueError: If ``conformer_rel_pos_type``, ``encoder_type`` or
            ``decoder_type`` is not supported.
        AssertionError: If a ``legacy_`` layer type is combined with
            ``conformer_rel_pos_type="latest"``.
    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.encoder_type = encoder_type
    self.decoder_type = decoder_type
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst
    self.spk_embed_dim = spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (
        ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
    )

    # check relative positional encoding compatibility
    if "conformer" in [encoder_type, decoder_type]:
        if conformer_rel_pos_type == "legacy":
            # NOTE: the switch that selects the new implementation is
            # conformer_rel_pos_type, so that is what the warnings name.
            if conformer_pos_enc_layer_type == "rel_pos":
                conformer_pos_enc_layer_type = "legacy_rel_pos"
                logging.warning(
                    "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_rel_pos_type = 'latest'."
                )
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
                logging.warning(
                    "Fallback to "
                    "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_rel_pos_type = 'latest'."
                )
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos", (
                "conformer_pos_enc_layer_type = 'legacy_rel_pos' cannot be "
                "used with conformer_rel_pos_type = 'latest'."
            )
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn", (
                "conformer_self_attn_layer_type = 'legacy_rel_selfattn' cannot "
                "be used with conformer_rel_pos_type = 'latest'."
            )
        else:
            raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")

    # define encoder
    encoder_input_layer = torch.nn.Embedding(
        num_embeddings=idim,
        embedding_dim=adim,
        padding_idx=self.padding_idx,
    )
    if encoder_type == "transformer":
        self.encoder = TransformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif encoder_type == "conformer":
        self.encoder = ConformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
        )
    else:
        raise ValueError(f"{encoder_type} is not supported.")

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define additional projection for speaker embedding
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    if decoder_type == "transformer":
        self.decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif decoder_type == "conformer":
        self.decoder = ConformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_dec_kernel_size,
        )
    else:
        raise ValueError(f"{decoder_type} is not supported.")

    # define final projection
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # define postnet
    self.postnet = (
        None
        if postnet_layers == 0
        else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )
    )

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # define criterions
    self.criterion = FastSpeechLoss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
    )
def __init__( self, idim: int, odim: int, adim: int = 256, aheads: int = 2, elayers: int = 4, eunits: int = 1024, dlayers: int = 4, dunits: int = 1024, positionwise_layer_type: str = "conv1d", positionwise_conv_kernel_size: int = 1, use_scaled_pos_enc: bool = True, use_batch_norm: bool = True, encoder_normalize_before: bool = True, decoder_normalize_before: bool = True, encoder_concat_after: bool = False, decoder_concat_after: bool = False, reduction_factor: int = 1, encoder_type: str = "transformer", decoder_type: str = "transformer", transformer_enc_dropout_rate: float = 0.1, transformer_enc_positional_dropout_rate: float = 0.1, transformer_enc_attn_dropout_rate: float = 0.1, transformer_dec_dropout_rate: float = 0.1, transformer_dec_positional_dropout_rate: float = 0.1, transformer_dec_attn_dropout_rate: float = 0.1, # only for conformer conformer_rel_pos_type: str = "legacy", conformer_pos_enc_layer_type: str = "rel_pos", conformer_self_attn_layer_type: str = "rel_selfattn", conformer_activation_type: str = "swish", use_macaron_style_in_conformer: bool = True, use_cnn_in_conformer: bool = True, zero_triu: bool = False, conformer_enc_kernel_size: int = 7, conformer_dec_kernel_size: int = 31, # duration predictor duration_predictor_layers: int = 2, duration_predictor_chans: int = 384, duration_predictor_kernel_size: int = 3, duration_predictor_dropout_rate: float = 0.1, # energy predictor energy_predictor_layers: int = 2, energy_predictor_chans: int = 384, energy_predictor_kernel_size: int = 3, energy_predictor_dropout: float = 0.5, energy_embed_kernel_size: int = 9, energy_embed_dropout: float = 0.5, stop_gradient_from_energy_predictor: bool = False, # pitch predictor pitch_predictor_layers: int = 2, pitch_predictor_chans: int = 384, pitch_predictor_kernel_size: int = 3, pitch_predictor_dropout: float = 0.5, pitch_embed_kernel_size: int = 9, pitch_embed_dropout: float = 0.5, stop_gradient_from_pitch_predictor: bool = False, # extra embedding related spks: 
Optional[int] = None, langs: Optional[int] = None, spk_embed_dim: Optional[int] = None, spk_embed_integration_type: str = "add", use_gst: bool = False, gst_tokens: int = 10, gst_heads: int = 4, gst_conv_layers: int = 6, gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128), gst_conv_kernel_size: int = 3, gst_conv_stride: int = 2, gst_gru_layers: int = 1, gst_gru_units: int = 128, # training related init_type: str = "xavier_uniform", init_enc_alpha: float = 1.0, init_dec_alpha: float = 1.0, use_masking: bool = False, use_weighted_masking: bool = False, segment_size: int = 64, # hifigan generator generator_out_channels: int = 1, generator_channels: int = 512, generator_global_channels: int = -1, generator_kernel_size: int = 7, generator_upsample_scales: List[int] = [8, 8, 2, 2], generator_upsample_kernel_sizes: List[int] = [16, 16, 4, 4], generator_resblock_kernel_sizes: List[int] = [3, 7, 11], generator_resblock_dilations: List[List[int]] = [ [1, 3, 5], [1, 3, 5], [1, 3, 5], ], generator_use_additional_convs: bool = True, generator_bias: bool = True, generator_nonlinear_activation: str = "LeakyReLU", generator_nonlinear_activation_params: Dict[str, Any] = { "negative_slope": 0.1 }, generator_use_weight_norm: bool = True, ): """Initialize JETS generator module. Args: idim (int): Dimension of the inputs. odim (int): Dimension of the outputs. elayers (int): Number of encoder layers. eunits (int): Number of encoder hidden units. dlayers (int): Number of decoder layers. dunits (int): Number of decoder hidden units. use_scaled_pos_enc (bool): Whether to use trainable scaled pos encoding. use_batch_norm (bool): Whether to use batch normalization in encoder prenet. encoder_normalize_before (bool): Whether to apply layernorm layer before encoder block. decoder_normalize_before (bool): Whether to apply layernorm layer before decoder block. encoder_concat_after (bool): Whether to concatenate attention layer's input and output in encoder. 
decoder_concat_after (bool): Whether to concatenate attention layer's input and output in decoder. reduction_factor (int): Reduction factor. encoder_type (str): Encoder type ("transformer" or "conformer"). decoder_type (str): Decoder type ("transformer" or "conformer"). transformer_enc_dropout_rate (float): Dropout rate in encoder except attention and positional encoding. transformer_enc_positional_dropout_rate (float): Dropout rate after encoder positional encoding. transformer_enc_attn_dropout_rate (float): Dropout rate in encoder self-attention module. transformer_dec_dropout_rate (float): Dropout rate in decoder except attention & positional encoding. transformer_dec_positional_dropout_rate (float): Dropout rate after decoder positional encoding. transformer_dec_attn_dropout_rate (float): Dropout rate in decoder self-attention module. conformer_rel_pos_type (str): Relative pos encoding type in conformer. conformer_pos_enc_layer_type (str): Pos encoding layer type in conformer. conformer_self_attn_layer_type (str): Self-attention layer type in conformer conformer_activation_type (str): Activation function type in conformer. use_macaron_style_in_conformer: Whether to use macaron style FFN. use_cnn_in_conformer: Whether to use CNN in conformer. zero_triu: Whether to use zero triu in relative self-attention module. conformer_enc_kernel_size: Kernel size of encoder conformer. conformer_dec_kernel_size: Kernel size of decoder conformer. duration_predictor_layers (int): Number of duration predictor layers. duration_predictor_chans (int): Number of duration predictor channels. duration_predictor_kernel_size (int): Kernel size of duration predictor. duration_predictor_dropout_rate (float): Dropout rate in duration predictor. pitch_predictor_layers (int): Number of pitch predictor layers. pitch_predictor_chans (int): Number of pitch predictor channels. pitch_predictor_kernel_size (int): Kernel size of pitch predictor. 
pitch_predictor_dropout_rate (float): Dropout rate in pitch predictor. pitch_embed_kernel_size (float): Kernel size of pitch embedding. pitch_embed_dropout_rate (float): Dropout rate for pitch embedding. stop_gradient_from_pitch_predictor: Whether to stop gradient from pitch predictor to encoder. energy_predictor_layers (int): Number of energy predictor layers. energy_predictor_chans (int): Number of energy predictor channels. energy_predictor_kernel_size (int): Kernel size of energy predictor. energy_predictor_dropout_rate (float): Dropout rate in energy predictor. energy_embed_kernel_size (float): Kernel size of energy embedding. energy_embed_dropout_rate (float): Dropout rate for energy embedding. stop_gradient_from_energy_predictor: Whether to stop gradient from energy predictor to encoder. spks (Optional[int]): Number of speakers. If set to > 1, assume that the sids will be provided as the input and use sid embedding layer. langs (Optional[int]): Number of languages. If set to > 1, assume that the lids will be provided as the input and use sid embedding layer. spk_embed_dim (Optional[int]): Speaker embedding dimension. If set to > 0, assume that spembs will be provided as the input. spk_embed_integration_type: How to integrate speaker embedding. use_gst (str): Whether to use global style token. gst_tokens (int): The number of GST embeddings. gst_heads (int): The number of heads in GST multihead attention. gst_conv_layers (int): The number of conv layers in GST. gst_conv_chans_list: (Sequence[int]): List of the number of channels of conv layers in GST. gst_conv_kernel_size (int): Kernel size of conv layers in GST. gst_conv_stride (int): Stride size of conv layers in GST. gst_gru_layers (int): The number of GRU layers in GST. gst_gru_units (int): The number of GRU units in GST. init_type (str): How to initialize transformer parameters. init_enc_alpha (float): Initial value of alpha in scaled pos encoding of the encoder. 
init_dec_alpha (float): Initial value of alpha in scaled pos encoding of the decoder. use_masking (bool): Whether to apply masking for padded part in loss calculation. use_weighted_masking (bool): Whether to apply weighted masking in loss calculation. segment_size (int): Segment size for random windowed discriminator generator_out_channels (int): Number of output channels. generator_channels (int): Number of hidden representation channels. generator_global_channels (int): Number of global conditioning channels. generator_kernel_size (int): Kernel size of initial and final conv layer. generator_upsample_scales (List[int]): List of upsampling scales. generator_upsample_kernel_sizes (List[int]): List of kernel sizes for upsample layers. generator_resblock_kernel_sizes (List[int]): List of kernel sizes for residual blocks. generator_resblock_dilations (List[List[int]]): List of list of dilations for residual blocks. generator_use_additional_convs (bool): Whether to use additional conv layers in residual blocks. generator_bias (bool): Whether to add bias parameter in convolution layers. generator_nonlinear_activation (str): Activation function module name. generator_nonlinear_activation_params (Dict[str, Any]): Hyperparameters for activation function. generator_use_weight_norm (bool): Whether to use weight norm. If set to true, it will be applied to all of the conv layers. 
""" super().__init__() self.segment_size = segment_size self.upsample_factor = int(np.prod(generator_upsample_scales)) self.idim = idim self.odim = odim self.reduction_factor = reduction_factor self.encoder_type = encoder_type self.decoder_type = decoder_type self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor self.use_scaled_pos_enc = use_scaled_pos_enc self.use_gst = use_gst # use idx 0 as padding idx self.padding_idx = 0 # get positional encoding class pos_enc_class = (ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding) # check relative positional encoding compatibility if "conformer" in [encoder_type, decoder_type]: if conformer_rel_pos_type == "legacy": if conformer_pos_enc_layer_type == "rel_pos": conformer_pos_enc_layer_type = "legacy_rel_pos" logging.warning( "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' " "due to the compatibility. If you want to use the new one, " "please use conformer_pos_enc_layer_type = 'latest'.") if conformer_self_attn_layer_type == "rel_selfattn": conformer_self_attn_layer_type = "legacy_rel_selfattn" logging.warning( "Fallback to " "conformer_self_attn_layer_type = 'legacy_rel_selfattn' " "due to the compatibility. 
If you want to use the new one, " "please use conformer_pos_enc_layer_type = 'latest'.") elif conformer_rel_pos_type == "latest": assert conformer_pos_enc_layer_type != "legacy_rel_pos" assert conformer_self_attn_layer_type != "legacy_rel_selfattn" else: raise ValueError( f"Unknown rel_pos_type: {conformer_rel_pos_type}") # define encoder encoder_input_layer = torch.nn.Embedding(num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx) if encoder_type == "transformer": self.encoder = TransformerEncoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) elif encoder_type == "conformer": self.encoder = ConformerEncoder( idim=idim, attention_dim=adim, attention_heads=aheads, linear_units=eunits, num_blocks=elayers, input_layer=encoder_input_layer, dropout_rate=transformer_enc_dropout_rate, positional_dropout_rate=transformer_enc_positional_dropout_rate, attention_dropout_rate=transformer_enc_attn_dropout_rate, normalize_before=encoder_normalize_before, concat_after=encoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer, pos_enc_layer_type=conformer_pos_enc_layer_type, selfattention_layer_type=conformer_self_attn_layer_type, activation_type=conformer_activation_type, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_enc_kernel_size, zero_triu=zero_triu, ) else: raise ValueError(f"{encoder_type} is not supported.") # define GST if self.use_gst: self.gst = 
StyleEncoder( idim=odim, # the input is mel-spectrogram gst_tokens=gst_tokens, gst_token_dim=adim, gst_heads=gst_heads, conv_layers=gst_conv_layers, conv_chans_list=gst_conv_chans_list, conv_kernel_size=gst_conv_kernel_size, conv_stride=gst_conv_stride, gru_layers=gst_gru_layers, gru_units=gst_gru_units, ) # define spk and lang embedding self.spks = None if spks is not None and spks > 1: self.spks = spks self.sid_emb = torch.nn.Embedding(spks, adim) self.langs = None if langs is not None and langs > 1: self.langs = langs self.lid_emb = torch.nn.Embedding(langs, adim) # define additional projection for speaker embedding self.spk_embed_dim = None if spk_embed_dim is not None and spk_embed_dim > 0: self.spk_embed_dim = spk_embed_dim self.spk_embed_integration_type = spk_embed_integration_type if self.spk_embed_dim is not None: if self.spk_embed_integration_type == "add": self.projection = torch.nn.Linear(self.spk_embed_dim, adim) else: self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim) # define duration predictor self.duration_predictor = DurationPredictor( idim=adim, n_layers=duration_predictor_layers, n_chans=duration_predictor_chans, kernel_size=duration_predictor_kernel_size, dropout_rate=duration_predictor_dropout_rate, ) # define pitch predictor self.pitch_predictor = VariancePredictor( idim=adim, n_layers=pitch_predictor_layers, n_chans=pitch_predictor_chans, kernel_size=pitch_predictor_kernel_size, dropout_rate=pitch_predictor_dropout, ) # NOTE(kan-bayashi): We use continuous pitch + FastPitch style avg self.pitch_embed = torch.nn.Sequential( torch.nn.Conv1d( in_channels=1, out_channels=adim, kernel_size=pitch_embed_kernel_size, padding=(pitch_embed_kernel_size - 1) // 2, ), torch.nn.Dropout(pitch_embed_dropout), ) # define energy predictor self.energy_predictor = VariancePredictor( idim=adim, n_layers=energy_predictor_layers, n_chans=energy_predictor_chans, kernel_size=energy_predictor_kernel_size, dropout_rate=energy_predictor_dropout, ) # 
NOTE(kan-bayashi): We use continuous enegy + FastPitch style avg self.energy_embed = torch.nn.Sequential( torch.nn.Conv1d( in_channels=1, out_channels=adim, kernel_size=energy_embed_kernel_size, padding=(energy_embed_kernel_size - 1) // 2, ), torch.nn.Dropout(energy_embed_dropout), ) # define AlignmentModule self.alignment_module = AlignmentModule(adim, odim) # define length regulator self.length_regulator = GaussianUpsampling() # define decoder # NOTE: we use encoder as decoder # because fastspeech's decoder is the same as encoder if decoder_type == "transformer": self.decoder = TransformerEncoder( idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None, dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate, attention_dropout_rate=transformer_dec_attn_dropout_rate, pos_enc_class=pos_enc_class, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, ) elif decoder_type == "conformer": self.decoder = ConformerEncoder( idim=0, attention_dim=adim, attention_heads=aheads, linear_units=dunits, num_blocks=dlayers, input_layer=None, dropout_rate=transformer_dec_dropout_rate, positional_dropout_rate=transformer_dec_positional_dropout_rate, attention_dropout_rate=transformer_dec_attn_dropout_rate, normalize_before=decoder_normalize_before, concat_after=decoder_concat_after, positionwise_layer_type=positionwise_layer_type, positionwise_conv_kernel_size=positionwise_conv_kernel_size, macaron_style=use_macaron_style_in_conformer, pos_enc_layer_type=conformer_pos_enc_layer_type, selfattention_layer_type=conformer_self_attn_layer_type, activation_type=conformer_activation_type, use_cnn_module=use_cnn_in_conformer, cnn_module_kernel=conformer_dec_kernel_size, ) else: raise ValueError(f"{decoder_type} is not supported.") # define hifigan 
generator self.generator = HiFiGANGenerator( in_channels=adim, out_channels=generator_out_channels, channels=generator_channels, global_channels=generator_global_channels, kernel_size=generator_kernel_size, upsample_scales=generator_upsample_scales, upsample_kernel_sizes=generator_upsample_kernel_sizes, resblock_kernel_sizes=generator_resblock_kernel_sizes, resblock_dilations=generator_resblock_dilations, use_additional_convs=generator_use_additional_convs, bias=generator_bias, nonlinear_activation=generator_nonlinear_activation, nonlinear_activation_params=generator_nonlinear_activation_params, use_weight_norm=generator_use_weight_norm, ) # initialize parameters self._reset_parameters( init_type=init_type, init_enc_alpha=init_enc_alpha, init_dec_alpha=init_dec_alpha, )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    embed_dim: int = 512,
    elayers: int = 1,
    eunits: int = 512,
    econv_layers: int = 3,
    econv_chans: int = 512,
    econv_filts: int = 5,
    atype: str = "location",
    adim: int = 512,
    aconv_chans: int = 32,
    aconv_filts: int = 15,
    cumulate_att_w: bool = True,
    dlayers: int = 2,
    dunits: int = 1024,
    prenet_layers: int = 2,
    prenet_units: int = 256,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    output_activation: str = None,
    use_batch_norm: bool = True,
    use_concate: bool = True,
    use_residual: bool = False,
    reduction_factor: int = 1,
    # extra embedding related
    spks: int = -1,
    langs: int = -1,
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "concat",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    dropout_rate: float = 0.5,
    zoneout_rate: float = 0.1,
    use_masking: bool = True,
    use_weighted_masking: bool = False,
    bce_pos_weight: float = 5.0,
    loss_type: str = "L1+L2",
    use_guided_attn_loss: bool = True,
    guided_attn_loss_sigma: float = 0.4,
    guided_attn_loss_lambda: float = 1.0,
):
    """Build the Tacotron2 encoder, attention, decoder, and loss modules.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        embed_dim (int): Dimension of the token embedding.
        elayers (int): Number of encoder blstm layers.
        eunits (int): Number of encoder blstm units.
        econv_layers (int): Number of encoder conv layers.
        econv_chans (int): Number of encoder conv filter channels.
        econv_filts (int): Number of encoder conv filter size.
        atype (str): Attention type ("location", "forward", or
            "forward_ta").
        adim (int): Number of dimension of mlp in attention.
        aconv_chans (int): Number of attention conv filter channels.
        aconv_filts (int): Number of attention conv filter size.
        cumulate_att_w (bool): Whether to cumulate previous attention
            weight. Forced to False for the forward attention variants.
        dlayers (int): Number of decoder lstm layers.
        dunits (int): Number of decoder lstm units.
        prenet_layers (int): Number of prenet layers.
        prenet_units (int): Number of prenet units.
        postnet_layers (int): Number of postnet layers.
        postnet_chans (int): Number of postnet filter channels.
        postnet_filts (int): Number of postnet filter size.
        output_activation (str): Name of activation function for outputs;
            looked up as an attribute of ``F``.
        use_batch_norm (bool): Whether to use batch normalization.
        use_concate (bool): Whether to concat enc outputs w/ dec lstm
            outputs.
        use_residual (bool): Forwarded to the encoder as ``use_residual``.
        reduction_factor (int): Reduction factor.
        spks (int): Number of speakers. If set to > 0, speaker ID embedding
            will be used.
        langs (int): Number of langs. If set to > 0, lang ID embedding will
            be used.
        spk_embed_dim (int): Pretrained speaker embedding dimension.
        spk_embed_integration_type (str): How to integrate speaker
            embedding ("concat" or "add").
        use_gst (bool): Whether to use global style token.
        gst_tokens (int): Number of GST embeddings.
        gst_heads (int): Number of heads in GST multihead attention.
        gst_conv_layers (int): Number of conv layers in GST.
        gst_conv_chans_list (Sequence[int]): List of the number of channels
            of conv layers in GST.
        gst_conv_kernel_size (int): Kernel size of conv layers in GST.
        gst_conv_stride (int): Stride size of conv layers in GST.
        gst_gru_layers (int): Number of GRU layers in GST.
        gst_gru_units (int): Number of GRU units in GST.
        dropout_rate (float): Dropout rate.
        zoneout_rate (float): Zoneout rate.
        use_masking (bool): Whether to mask padded part in loss
            calculation.
        use_weighted_masking (bool): Whether to apply weighted masking in
            loss calculation.
        bce_pos_weight (float): Weight of positive sample of stop token
            (only for use_masking=True).
        loss_type (str): Loss function type ("L1", "L2", or "L1+L2").
        use_guided_attn_loss (bool): Whether to use guided attention loss.
        guided_attn_loss_sigma (float): Sigma in guided attention loss.
        guided_attn_loss_lambda (float): Lambda in guided attention loss.

    Raises:
        ValueError: If ``output_activation`` is not an attribute of ``F``,
            or ``spk_embed_integration_type`` is neither "concat" nor
            "add" while ``spk_embed_dim`` is given.
        NotImplementedError: If ``atype`` is not a supported attention
            type.

    """
    assert check_argument_types()
    super().__init__()

    # hyperparameters kept on the instance
    self.idim, self.odim = idim, odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.cumulate_att_w = cumulate_att_w
    self.loss_type = loss_type
    self.use_gst = use_gst
    self.use_guided_attn_loss = use_guided_attn_loss
    self.spks, self.langs = spks, langs
    self.spk_embed_dim = spk_embed_dim
    if spk_embed_dim is not None:
        # only meaningful when a pretrained speaker embedding is supplied
        self.spk_embed_integration_type = spk_embed_integration_type

    # resolve the activation applied to the final output
    activation_fn = None
    if output_activation is not None:
        if not hasattr(F, output_activation):
            raise ValueError(
                f"there is no such an activation function. "
                f"({output_activation})"
            )
        activation_fn = getattr(F, output_activation)
    self.output_activation_fn = activation_fn

    # index 0 is reserved for padding
    self.padding_idx = 0

    # text encoder
    self.enc = Encoder(
        idim=idim,
        embed_dim=embed_dim,
        elayers=elayers,
        eunits=eunits,
        econv_layers=econv_layers,
        econv_chans=econv_chans,
        econv_filts=econv_filts,
        use_batch_norm=use_batch_norm,
        use_residual=use_residual,
        dropout_rate=dropout_rate,
        padding_idx=self.padding_idx,
    )

    # optional global style token encoder (consumes mel-spectrograms)
    if use_gst:
        self.gst = StyleEncoder(
            idim=odim,
            gst_tokens=gst_tokens,
            gst_token_dim=eunits,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # optional speaker-ID / language-ID lookup tables
    # NOTE(review): these tables are embed_dim wide whereas the GST tokens
    # and the speaker-embedding projection are eunits wide -- confirm which
    # width the forward pass expects when embed_dim != eunits.
    if spks > 0:
        self.sid_emb = torch.nn.Embedding(spks, embed_dim)
    if langs > 0:
        self.lid_emb = torch.nn.Embedding(langs, embed_dim)

    # the decoder input widens only when a speaker embedding is concatenated
    decoder_input_dim = eunits
    if spk_embed_dim is not None:
        if spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, eunits)
        elif spk_embed_integration_type == "concat":
            decoder_input_dim = eunits + spk_embed_dim
        else:
            raise ValueError(f"{spk_embed_integration_type} is not supported.")

    # attention module handed to the decoder
    if atype not in ("location", "forward", "forward_ta"):
        raise NotImplementedError("Support only location or forward")
    common_att_args = (decoder_input_dim, dunits, adim, aconv_chans, aconv_filts)
    if atype == "location":
        attention = AttLoc(*common_att_args)
    elif atype == "forward":
        attention = AttForward(*common_att_args)
    else:
        attention = AttForwardTA(*common_att_args, odim)
    if atype != "location" and self.cumulate_att_w:
        # the forward attention variants do not cumulate attention weights
        logging.warning(
            "cumulation of attention weights is disabled "
            "in forward attention."
        )
        self.cumulate_att_w = False

    # autoregressive decoder
    self.dec = Decoder(
        idim=decoder_input_dim,
        odim=odim,
        att=attention,
        dlayers=dlayers,
        dunits=dunits,
        prenet_layers=prenet_layers,
        prenet_units=prenet_units,
        postnet_layers=postnet_layers,
        postnet_chans=postnet_chans,
        postnet_filts=postnet_filts,
        output_activation_fn=self.output_activation_fn,
        cumulate_att_w=self.cumulate_att_w,
        use_batch_norm=use_batch_norm,
        use_concate=use_concate,
        dropout_rate=dropout_rate,
        zoneout_rate=zoneout_rate,
        reduction_factor=reduction_factor,
    )

    # training criteria
    self.taco2_loss = Tacotron2Loss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
        bce_pos_weight=bce_pos_weight,
    )
    if use_guided_attn_loss:
        self.attn_loss = GuidedAttentionLoss(
            sigma=guided_attn_loss_sigma,
            alpha=guided_attn_loss_lambda,
        )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    postnet_dropout_rate: float = 0.5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    duration_predictor_dropout_rate: float = 0.1,
    reduction_factor: int = 1,
    encoder_type: str = "transformer",
    decoder_type: str = "transformer",
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    # only for conformer
    conformer_rel_pos_type: str = "legacy",
    conformer_pos_enc_layer_type: str = "rel_pos",
    conformer_self_attn_layer_type: str = "rel_selfattn",
    conformer_activation_type: str = "swish",
    use_macaron_style_in_conformer: bool = True,
    use_cnn_in_conformer: bool = True,
    conformer_enc_kernel_size: int = 7,
    conformer_dec_kernel_size: int = 31,
    zero_triu: bool = False,
    # extra embedding related
    spks: Optional[int] = None,
    langs: Optional[int] = None,
    spk_embed_dim: Optional[int] = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Initialize FastSpeech module.

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        adim (int): Attention dimension; also used as the width of the
            token embedding and of every conditioning embedding.
        aheads (int): Number of attention heads.
        elayers (int): Number of encoder layers.
        eunits (int): Number of encoder hidden units.
        dlayers (int): Number of decoder layers.
        dunits (int): Number of decoder hidden units.
        postnet_layers (int): Number of postnet layers. If 0, no postnet
            is created.
        postnet_chans (int): Number of postnet channels.
        postnet_filts (int): Kernel size of postnet.
        postnet_dropout_rate (float): Dropout rate in postnet.
        positionwise_layer_type (str): Position-wise layer type, forwarded
            to both encoder and decoder.
        positionwise_conv_kernel_size (int): Kernel size of the
            position-wise conv layer, forwarded to both encoder and
            decoder.
        use_scaled_pos_enc (bool): Whether to use trainable scaled pos
            encoding.
        use_batch_norm (bool): Whether to use batch normalization in
            encoder prenet.
        encoder_normalize_before (bool): Whether to apply layernorm layer
            before encoder block.
        decoder_normalize_before (bool): Whether to apply layernorm layer
            before decoder block.
        encoder_concat_after (bool): Whether to concatenate attention
            layer's input and output in encoder.
        decoder_concat_after (bool): Whether to concatenate attention
            layer's input and output in decoder.
        duration_predictor_layers (int): Number of duration predictor
            layers.
        duration_predictor_chans (int): Number of duration predictor
            channels.
        duration_predictor_kernel_size (int): Kernel size of duration
            predictor.
        duration_predictor_dropout_rate (float): Dropout rate in duration
            predictor.
        reduction_factor (int): Reduction factor.
        encoder_type (str): Encoder type ("transformer" or "conformer").
        decoder_type (str): Decoder type ("transformer" or "conformer").
        transformer_enc_dropout_rate (float): Dropout rate in encoder
            except attention and positional encoding.
        transformer_enc_positional_dropout_rate (float): Dropout rate after
            encoder positional encoding.
        transformer_enc_attn_dropout_rate (float): Dropout rate in encoder
            self-attention module.
        transformer_dec_dropout_rate (float): Dropout rate in decoder
            except attention & positional encoding.
        transformer_dec_positional_dropout_rate (float): Dropout rate after
            decoder positional encoding.
        transformer_dec_attn_dropout_rate (float): Dropout rate in decoder
            self-attention module.
        conformer_rel_pos_type (str): Relative pos encoding type in
            conformer ("legacy" or "latest").
        conformer_pos_enc_layer_type (str): Pos encoding layer type in
            conformer.
        conformer_self_attn_layer_type (str): Self-attention layer type in
            conformer.
        conformer_activation_type (str): Activation function type in
            conformer.
        use_macaron_style_in_conformer (bool): Whether to use macaron style
            FFN.
        use_cnn_in_conformer (bool): Whether to use CNN in conformer.
        conformer_enc_kernel_size (int): Kernel size of encoder conformer.
        conformer_dec_kernel_size (int): Kernel size of decoder conformer.
        zero_triu (bool): Whether to use zero triu in relative
            self-attention module. Forwarded to the conformer encoder.
        spks (Optional[int]): Number of speakers. If set to > 1, assume
            that the sids will be provided as the input and use sid
            embedding layer.
        langs (Optional[int]): Number of languages. If set to > 1, assume
            that the lids will be provided as the input and use lid
            embedding layer.
        spk_embed_dim (Optional[int]): Speaker embedding dimension. If set
            to > 0, assume that spembs will be provided as the input.
        spk_embed_integration_type (str): How to integrate speaker
            embedding. "add" projects the embedding to ``adim``; any other
            value projects the concatenation back to ``adim``.
        use_gst (bool): Whether to use global style token.
        gst_tokens (int): The number of GST embeddings.
        gst_heads (int): The number of heads in GST multihead attention.
        gst_conv_layers (int): The number of conv layers in GST.
        gst_conv_chans_list (Sequence[int]): List of the number of channels
            of conv layers in GST.
        gst_conv_kernel_size (int): Kernel size of conv layers in GST.
        gst_conv_stride (int): Stride size of conv layers in GST.
        gst_gru_layers (int): The number of GRU layers in GST.
        gst_gru_units (int): The number of GRU units in GST.
        init_type (str): How to initialize transformer parameters.
        init_enc_alpha (float): Initial value of alpha in scaled pos
            encoding of the encoder.
        init_dec_alpha (float): Initial value of alpha in scaled pos
            encoding of the decoder.
        use_masking (bool): Whether to apply masking for padded part in
            loss calculation.
        use_weighted_masking (bool): Whether to apply weighted masking in
            loss calculation.

    Raises:
        ValueError: If ``conformer_rel_pos_type``, ``encoder_type`` or
            ``decoder_type`` is not a supported value.
        AssertionError: If ``conformer_rel_pos_type`` is "latest" but a
            legacy conformer layer type is requested.

    """
    assert check_argument_types()
    super().__init__()

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.encoder_type = encoder_type
    self.decoder_type = decoder_type
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (
        ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
    )

    # check relative positional encoding compatibility
    if "conformer" in [encoder_type, decoder_type]:
        if conformer_rel_pos_type == "legacy":
            # Under the legacy setting the generic layer names are remapped
            # to their legacy implementations. The switch that selects the
            # new implementation is conformer_rel_pos_type, so that is what
            # the warnings point the user at.
            if conformer_pos_enc_layer_type == "rel_pos":
                conformer_pos_enc_layer_type = "legacy_rel_pos"
                logging.warning(
                    "Fallback to conformer_pos_enc_layer_type = 'legacy_rel_pos' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_rel_pos_type = 'latest'."
                )
            if conformer_self_attn_layer_type == "rel_selfattn":
                conformer_self_attn_layer_type = "legacy_rel_selfattn"
                logging.warning(
                    "Fallback to "
                    "conformer_self_attn_layer_type = 'legacy_rel_selfattn' "
                    "due to the compatibility. If you want to use the new one, "
                    "please use conformer_rel_pos_type = 'latest'."
                )
        elif conformer_rel_pos_type == "latest":
            assert conformer_pos_enc_layer_type != "legacy_rel_pos"
            assert conformer_self_attn_layer_type != "legacy_rel_selfattn"
        else:
            raise ValueError(f"Unknown rel_pos_type: {conformer_rel_pos_type}")

    # define encoder
    encoder_input_layer = torch.nn.Embedding(
        num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
    )
    if encoder_type == "transformer":
        self.encoder = TransformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif encoder_type == "conformer":
        self.encoder = ConformerEncoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            normalize_before=encoder_normalize_before,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_enc_kernel_size,
            # previously accepted by the constructor but silently dropped
            zero_triu=zero_triu,
        )
    else:
        raise ValueError(f"{encoder_type} is not supported.")

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define spk and lang embedding
    self.spks = None
    if spks is not None and spks > 1:
        self.spks = spks
        self.sid_emb = torch.nn.Embedding(spks, adim)
    self.langs = None
    if langs is not None and langs > 1:
        self.langs = langs
        self.lid_emb = torch.nn.Embedding(langs, adim)

    # define additional projection for speaker embedding
    self.spk_embed_dim = None
    if spk_embed_dim is not None and spk_embed_dim > 0:
        self.spk_embed_dim = spk_embed_dim
        self.spk_embed_integration_type = spk_embed_integration_type
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    # define duration predictor
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # define length regulator
    self.length_regulator = LengthRegulator()

    # define decoder
    # NOTE: we use encoder as decoder
    # because fastspeech's decoder is the same as encoder
    if decoder_type == "transformer":
        self.decoder = TransformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
        )
    elif decoder_type == "conformer":
        # zero_triu is not forwarded here; the decoder keeps the
        # ConformerEncoder default for that option.
        self.decoder = ConformerEncoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            normalize_before=decoder_normalize_before,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            macaron_style=use_macaron_style_in_conformer,
            pos_enc_layer_type=conformer_pos_enc_layer_type,
            selfattention_layer_type=conformer_self_attn_layer_type,
            activation_type=conformer_activation_type,
            use_cnn_module=use_cnn_in_conformer,
            cnn_module_kernel=conformer_dec_kernel_size,
        )
    else:
        raise ValueError(f"{decoder_type} is not supported.")

    # define final projection
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # define postnet
    self.postnet = (
        None
        if postnet_layers == 0
        else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )
    )

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # define criterions
    self.criterion = FastSpeechLoss(
        use_masking=use_masking, use_weighted_masking=use_weighted_masking
    )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = False,
    decoder_normalize_before: bool = False,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    reduction_factor: int = 1,
    encoder_type: str = "transformer",
    decoder_type: str = "transformer",
    # only for conformer
    conformer_pos_enc_layer_type: str = "rel_pos",
    conformer_self_attn_layer_type: str = "rel_selfattn",
    conformer_activation_type: str = "swish",
    use_macaron_style_in_conformer: bool = True,
    use_cnn_in_conformer: bool = True,
    conformer_enc_kernel_size: int = 7,
    conformer_dec_kernel_size: int = 31,
    # duration predictor
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    # energy predictor
    energy_predictor_layers: int = 2,
    energy_predictor_chans: int = 384,
    energy_predictor_kernel_size: int = 3,
    energy_predictor_dropout: float = 0.5,
    energy_embed_kernel_size: int = 9,
    energy_embed_dropout: float = 0.5,
    stop_gradient_from_energy_predictor: bool = False,
    # pitch predictor
    pitch_predictor_layers: int = 2,
    pitch_predictor_chans: int = 384,
    pitch_predictor_kernel_size: int = 3,
    pitch_predictor_dropout: float = 0.5,
    pitch_embed_kernel_size: int = 9,
    pitch_embed_dropout: float = 0.5,
    stop_gradient_from_pitch_predictor: bool = False,
    # pretrained spk emb
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    # GST
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    duration_predictor_dropout_rate: float = 0.1,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Construct all sub-modules of the FastSpeech2 model.

    Sub-modules are created in a fixed order: text encoder, optional GST
    style encoder, optional speaker-embedding projection, duration / pitch /
    energy predictors (each variance predictor followed by its embedding
    convolution), length regulator, decoder, output projection and optional
    postnet.  Parameters are then initialised via ``_reset_parameters`` and
    a ``FastSpeech2Loss`` criterion is attached.

    Args:
        idim: Input vocabulary size; used as ``num_embeddings`` of the
            encoder embedding (index 0 is the padding index).
        odim: Output feature dimension; the final projection emits
            ``odim * reduction_factor`` values per decoder frame.
        adim: Attention dimension shared by encoder, decoder, predictors
            and the pitch / energy embeddings.
        aheads: Number of attention heads in encoder and decoder.
        elayers, eunits: Number of blocks / linear units of the encoder.
        dlayers, dunits: Number of blocks / linear units of the decoder.
        postnet_layers, postnet_chans, postnet_filts: Postnet layout; no
            postnet is built when ``postnet_layers`` is 0.
        positionwise_layer_type, positionwise_conv_kernel_size: Forwarded
            to both encoder and decoder.
        use_scaled_pos_enc: Pick ``ScaledPositionalEncoding`` instead of
            ``PositionalEncoding`` for the transformer variants.
        use_batch_norm, postnet_dropout_rate: Forwarded to the postnet.
        encoder_normalize_before, decoder_normalize_before,
        encoder_concat_after, decoder_concat_after: Forwarded to the
            encoder / decoder respectively.
        reduction_factor: Number of output frames produced per decoder step.
        encoder_type, decoder_type: ``"transformer"`` or ``"conformer"``.
        conformer_*, use_macaron_style_in_conformer, use_cnn_in_conformer:
            Only forwarded when the matching type is ``"conformer"``.
        duration_predictor_*: Layout / dropout of the duration predictor.
        energy_predictor_*, energy_embed_*: Layout / dropout of the energy
            predictor and of the 1-d convolution embedding the energy.
        pitch_predictor_*, pitch_embed_*: Same for pitch.
        stop_gradient_from_energy_predictor,
        stop_gradient_from_pitch_predictor: Stored on the instance only.
        spk_embed_dim: Speaker embedding size; ``None`` disables the
            speaker projection layer.
        spk_embed_integration_type: ``"add"`` builds a
            ``spk_embed_dim -> adim`` projection, anything else builds an
            ``adim + spk_embed_dim -> adim`` projection.
        use_gst, gst_*: Enable and configure the ``StyleEncoder``.
        transformer_enc_*_rate, transformer_dec_*_rate: Dropout rates
            forwarded to encoder / decoder.
        init_type, init_enc_alpha, init_dec_alpha: Forwarded to
            ``_reset_parameters``.
        use_masking, use_weighted_masking: Forwarded to the criterion.

    Raises:
        ValueError: If ``encoder_type`` or ``decoder_type`` is neither
            ``"transformer"`` nor ``"conformer"``.
    """
    assert check_argument_types()
    super().__init__()

    # values that are read back from the instance later on
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.encoder_type = encoder_type
    self.decoder_type = decoder_type
    self.stop_gradient_from_pitch_predictor = stop_gradient_from_pitch_predictor
    self.stop_gradient_from_energy_predictor = stop_gradient_from_energy_predictor
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst
    self.spk_embed_dim = spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # index 0 of the vocabulary is reserved for padding
    self.padding_idx = 0

    # positional encoding class used by the transformer encoder / decoder
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # keyword arguments common to both conformer stacks
    conformer_kwargs = dict(
        macaron_style=use_macaron_style_in_conformer,
        pos_enc_layer_type=conformer_pos_enc_layer_type,
        selfattention_layer_type=conformer_self_attn_layer_type,
        activation_type=conformer_activation_type,
        use_cnn_module=use_cnn_in_conformer,
    )

    # ---- encoder ----
    encoder_input_layer = torch.nn.Embedding(
        num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
    )
    encoder_kwargs = dict(
        idim=idim,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )
    if encoder_type == "transformer":
        self.encoder = TransformerEncoder(
            pos_enc_class=pos_enc_class, **encoder_kwargs
        )
    elif encoder_type == "conformer":
        self.encoder = ConformerEncoder(
            cnn_module_kernel=conformer_enc_kernel_size,
            **conformer_kwargs,
            **encoder_kwargs,
        )
    else:
        raise ValueError(f"{encoder_type} is not supported.")

    # ---- global style tokens (input of the style encoder is the mel) ----
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # ---- projection used to merge an external speaker embedding ----
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, adim)

    # ---- duration predictor ----
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # ---- pitch predictor and embedding ----
    self.pitch_predictor = VariancePredictor(
        idim=adim,
        n_layers=pitch_predictor_layers,
        n_chans=pitch_predictor_chans,
        kernel_size=pitch_predictor_kernel_size,
        dropout_rate=pitch_predictor_dropout,
    )
    # NOTE(kan-bayashi): continuous pitch + FastPitch style averaging
    pitch_conv = torch.nn.Conv1d(
        in_channels=1,
        out_channels=adim,
        kernel_size=pitch_embed_kernel_size,
        padding=(pitch_embed_kernel_size - 1) // 2,
    )
    self.pitch_embed = torch.nn.Sequential(
        pitch_conv, torch.nn.Dropout(pitch_embed_dropout)
    )

    # ---- energy predictor and embedding ----
    self.energy_predictor = VariancePredictor(
        idim=adim,
        n_layers=energy_predictor_layers,
        n_chans=energy_predictor_chans,
        kernel_size=energy_predictor_kernel_size,
        dropout_rate=energy_predictor_dropout,
    )
    # NOTE(kan-bayashi): continuous energy + FastPitch style averaging
    energy_conv = torch.nn.Conv1d(
        in_channels=1,
        out_channels=adim,
        kernel_size=energy_embed_kernel_size,
        padding=(energy_embed_kernel_size - 1) // 2,
    )
    self.energy_embed = torch.nn.Sequential(
        energy_conv, torch.nn.Dropout(energy_embed_dropout)
    )

    # ---- length regulator ----
    self.length_regulator = LengthRegulator()

    # ---- decoder ----
    # NOTE: the decoder reuses the encoder classes because the FastSpeech
    # decoder has the same structure as the encoder.
    decoder_kwargs = dict(
        idim=0,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=dunits,
        num_blocks=dlayers,
        input_layer=None,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        attention_dropout_rate=transformer_dec_attn_dropout_rate,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )
    if decoder_type == "transformer":
        self.decoder = TransformerEncoder(
            pos_enc_class=pos_enc_class, **decoder_kwargs
        )
    elif decoder_type == "conformer":
        self.decoder = ConformerEncoder(
            cnn_module_kernel=conformer_dec_kernel_size,
            **conformer_kwargs,
            **decoder_kwargs,
        )
    else:
        raise ValueError(f"{decoder_type} is not supported.")

    # ---- final projection to the output features ----
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # ---- postnet (skipped entirely when it has no layers) ----
    if postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )

    # ---- parameter initialisation ----
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # ---- loss ----
    self.criterion = FastSpeech2Loss(
        use_masking=use_masking, use_weighted_masking=use_weighted_masking
    )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    embed_dim: int = 512,
    eprenet_conv_layers: int = 3,
    eprenet_conv_chans: int = 256,
    eprenet_conv_filts: int = 5,
    dprenet_layers: int = 2,
    dprenet_units: int = 256,
    elayers: int = 6,
    eunits: int = 1024,
    adim: int = 512,
    aheads: int = 4,
    dlayers: int = 6,
    dunits: int = 1024,
    postnet_layers: int = 5,
    postnet_chans: int = 256,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = True,
    decoder_normalize_before: bool = True,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    reduction_factor: int = 1,
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    transformer_enc_dec_attn_dropout_rate: float = 0.1,
    eprenet_dropout_rate: float = 0.5,
    dprenet_dropout_rate: float = 0.5,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
    bce_pos_weight: float = 5.0,
    loss_type: str = "L1",
    use_guided_attn_loss: bool = True,
    num_heads_applied_guided_attn: int = 2,
    num_layers_applied_guided_attn: int = 2,
    modules_applied_guided_attn: Sequence[str] = ("encoder-decoder",),
    guided_attn_loss_sigma: float = 0.4,
    guided_attn_loss_lambda: float = 1.0,
):
    """Initialize Transformer module.

    Builds, in order: the encoder (with an optional convolutional prenet),
    the optional GST style encoder, the optional speaker-embedding
    projection, the decoder (with an optional prenet), the feature and
    stop-probability projections, the optional postnet and the loss
    criteria; finally ``_reset_parameters`` is called.

    Args:
        idim: Input vocabulary size (index 0 is the padding index).
        odim: Output feature dimension; ``feat_out`` emits
            ``odim * reduction_factor`` values per decoder step.
        embed_dim, eprenet_conv_layers, eprenet_conv_chans,
        eprenet_conv_filts, eprenet_dropout_rate: Encoder prenet layout.
            With ``eprenet_conv_layers == 0`` a plain embedding of size
            ``adim`` is used instead of the prenet.
        dprenet_layers, dprenet_units, dprenet_dropout_rate: Decoder prenet
            layout.  With ``dprenet_layers == 0`` the decoder input layer
            is ``"linear"``.
        elayers, eunits, dlayers, dunits: Number of blocks / linear units
            of encoder and decoder.
        adim, aheads: Attention dimension and number of heads.
        postnet_layers, postnet_chans, postnet_filts, postnet_dropout_rate,
        use_batch_norm: Postnet layout; no postnet when
            ``postnet_layers`` is 0.
        positionwise_layer_type, positionwise_conv_kernel_size: Forwarded
            to the encoder.
        use_scaled_pos_enc: Pick ``ScaledPositionalEncoding`` instead of
            ``PositionalEncoding``.
        encoder_normalize_before, decoder_normalize_before,
        encoder_concat_after, decoder_concat_after: Forwarded to the
            encoder / decoder respectively.
        reduction_factor: Number of frames produced per decoder step;
            ``prob_out`` emits this many stop logits per step.
        spk_embed_dim: Speaker embedding size; ``None`` disables the
            speaker projection layer.
        spk_embed_integration_type: ``"add"`` builds a
            ``spk_embed_dim -> adim`` projection, anything else builds an
            ``adim + spk_embed_dim -> adim`` projection.
        use_gst, gst_*: Enable and configure the ``StyleEncoder``.
        transformer_*_dropout_rate: Dropout rates forwarded to the encoder
            / decoder.
        init_type, init_enc_alpha, init_dec_alpha: Forwarded to
            ``_reset_parameters``.
        use_masking, use_weighted_masking, bce_pos_weight: Forwarded to
            ``TransformerLoss``.
        loss_type: Stored on the instance.
        use_guided_attn_loss: Whether to create the guided attention
            criterion and store its settings.
        num_heads_applied_guided_attn, num_layers_applied_guided_attn:
            ``-1`` is replaced by ``aheads`` / ``elayers`` respectively.
        modules_applied_guided_attn: Names of the modules the guided
            attention loss is applied to.  A bare string is accepted and
            treated as a single name.
        guided_attn_loss_sigma, guided_attn_loss_lambda: Passed to
            ``GuidedMultiHeadAttentionLoss`` as ``sigma`` and ``alpha``.
    """
    assert check_argument_types()
    super().__init__()

    # A bare string also satisfies ``Sequence[str]``, but membership tests
    # on it are substring matches ("encoder" in "encoder-decoder" is True).
    # Wrap it so that every entry is compared as a whole module name.
    if isinstance(modules_applied_guided_attn, str):
        modules_applied_guided_attn = (modules_applied_guided_attn,)

    # store hyperparameters
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.spk_embed_dim = spk_embed_dim
    self.reduction_factor = reduction_factor
    self.use_gst = use_gst
    self.use_guided_attn_loss = use_guided_attn_loss
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.loss_type = loss_type
    if self.use_guided_attn_loss:
        # -1 means "apply to all layers / all heads"
        if num_layers_applied_guided_attn == -1:
            self.num_layers_applied_guided_attn = elayers
        else:
            self.num_layers_applied_guided_attn = num_layers_applied_guided_attn
        if num_heads_applied_guided_attn == -1:
            self.num_heads_applied_guided_attn = aheads
        else:
            self.num_heads_applied_guided_attn = num_heads_applied_guided_attn
        self.modules_applied_guided_attn = modules_applied_guided_attn
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # use idx 0 as padding idx
    self.padding_idx = 0

    # get positional encoding class
    pos_enc_class = (
        ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
    )

    # define transformer encoder
    if eprenet_conv_layers != 0:
        # encoder prenet followed by a projection to the attention dim
        encoder_input_layer = torch.nn.Sequential(
            EncoderPrenet(
                idim=idim,
                embed_dim=embed_dim,
                elayers=0,
                econv_layers=eprenet_conv_layers,
                econv_chans=eprenet_conv_chans,
                econv_filts=eprenet_conv_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=eprenet_dropout_rate,
                padding_idx=self.padding_idx,
            ),
            torch.nn.Linear(eprenet_conv_chans, adim),
        )
    else:
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
        )
    self.encoder = Encoder(
        idim=idim,
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        pos_enc_class=pos_enc_class,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )

    # define GST
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,  # the input is mel-spectrogram
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # define projection layer for the speaker embedding
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
        else:
            self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

    # define transformer decoder
    if dprenet_layers != 0:
        # decoder prenet followed by a projection to the attention dim
        decoder_input_layer = torch.nn.Sequential(
            DecoderPrenet(
                idim=odim,
                n_layers=dprenet_layers,
                n_units=dprenet_units,
                dropout_rate=dprenet_dropout_rate,
            ),
            torch.nn.Linear(dprenet_units, adim),
        )
    else:
        decoder_input_layer = "linear"
    self.decoder = Decoder(
        odim=odim,  # odim is needed when no prenet is used
        attention_dim=adim,
        attention_heads=aheads,
        linear_units=dunits,
        num_blocks=dlayers,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        self_attention_dropout_rate=transformer_dec_attn_dropout_rate,
        src_attention_dropout_rate=transformer_enc_dec_attn_dropout_rate,
        input_layer=decoder_input_layer,
        use_output_layer=False,
        pos_enc_class=pos_enc_class,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
    )

    # define final projections (features and stop probability)
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)
    self.prob_out = torch.nn.Linear(adim, reduction_factor)

    # define postnet
    self.postnet = (
        None
        if postnet_layers == 0
        else Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )
    )

    # define loss functions
    self.criterion = TransformerLoss(
        use_masking=use_masking,
        use_weighted_masking=use_weighted_masking,
        bce_pos_weight=bce_pos_weight,
    )
    if self.use_guided_attn_loss:
        self.attn_criterion = GuidedMultiHeadAttentionLoss(
            sigma=guided_attn_loss_sigma,
            alpha=guided_attn_loss_lambda,
        )

    # initialize parameters
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )
def __init__(
    self,
    # network structure related
    idim: int,
    odim: int,
    adim: int = 384,
    aheads: int = 4,
    elayers: int = 6,
    eunits: int = 1536,
    dlayers: int = 6,
    dunits: int = 1536,
    postnet_layers: int = 5,
    postnet_chans: int = 512,
    postnet_filts: int = 5,
    positionwise_layer_type: str = "conv1d",
    positionwise_conv_kernel_size: int = 1,
    use_scaled_pos_enc: bool = True,
    use_batch_norm: bool = True,
    encoder_normalize_before: bool = False,
    decoder_normalize_before: bool = False,
    encoder_concat_after: bool = False,
    decoder_concat_after: bool = False,
    duration_predictor_layers: int = 2,
    duration_predictor_chans: int = 384,
    duration_predictor_kernel_size: int = 3,
    reduction_factor: int = 1,
    spk_embed_dim: int = None,
    spk_embed_integration_type: str = "add",
    use_gst: bool = False,
    gst_tokens: int = 10,
    gst_heads: int = 4,
    gst_conv_layers: int = 6,
    gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
    gst_conv_kernel_size: int = 3,
    gst_conv_stride: int = 2,
    gst_gru_layers: int = 1,
    gst_gru_units: int = 128,
    # training related
    transformer_enc_dropout_rate: float = 0.1,
    transformer_enc_positional_dropout_rate: float = 0.1,
    transformer_enc_attn_dropout_rate: float = 0.1,
    transformer_dec_dropout_rate: float = 0.1,
    transformer_dec_positional_dropout_rate: float = 0.1,
    transformer_dec_attn_dropout_rate: float = 0.1,
    duration_predictor_dropout_rate: float = 0.1,
    postnet_dropout_rate: float = 0.5,
    init_type: str = "xavier_uniform",
    init_enc_alpha: float = 1.0,
    init_dec_alpha: float = 1.0,
    use_masking: bool = False,
    use_weighted_masking: bool = False,
):
    """Construct all sub-modules of the FastSpeech model.

    Sub-modules are created in a fixed order: text encoder, optional GST
    style encoder, optional speaker-embedding projection, duration
    predictor, length regulator, decoder, output projection and optional
    postnet.  Parameters are then initialised via ``_reset_parameters`` and
    a ``FastSpeechLoss`` criterion is attached.

    Args:
        idim: Input vocabulary size; used as ``num_embeddings`` of the
            encoder embedding (index 0 is the padding index).
        odim: Output feature dimension; the final projection emits
            ``odim * reduction_factor`` values per decoder frame.
        adim, aheads: Attention dimension and number of heads shared by
            encoder and decoder.
        elayers, eunits, dlayers, dunits: Number of blocks / linear units
            of encoder and decoder.
        postnet_layers, postnet_chans, postnet_filts, postnet_dropout_rate,
        use_batch_norm: Postnet layout; no postnet when
            ``postnet_layers`` is 0.
        positionwise_layer_type, positionwise_conv_kernel_size: Forwarded
            to both encoder and decoder.
        use_scaled_pos_enc: Pick ``ScaledPositionalEncoding`` instead of
            ``PositionalEncoding``.
        encoder_normalize_before, decoder_normalize_before,
        encoder_concat_after, decoder_concat_after: Forwarded to the
            encoder / decoder respectively.
        duration_predictor_*: Layout / dropout of the duration predictor.
        reduction_factor: Number of output frames produced per decoder step.
        spk_embed_dim: Speaker embedding size; ``None`` disables the
            speaker projection layer.
        spk_embed_integration_type: ``"add"`` builds a
            ``spk_embed_dim -> adim`` projection, anything else builds an
            ``adim + spk_embed_dim -> adim`` projection.
        use_gst, gst_*: Enable and configure the ``StyleEncoder``.
        transformer_enc_*_rate, transformer_dec_*_rate: Dropout rates
            forwarded to encoder / decoder.
        init_type, init_enc_alpha, init_dec_alpha: Forwarded to
            ``_reset_parameters``.
        use_masking, use_weighted_masking: Forwarded to the criterion.
    """
    assert check_argument_types()
    super().__init__()

    # values that are read back from the instance later on
    self.idim = idim
    self.odim = odim
    self.eos = idim - 1
    self.reduction_factor = reduction_factor
    self.use_scaled_pos_enc = use_scaled_pos_enc
    self.use_gst = use_gst
    self.spk_embed_dim = spk_embed_dim
    if self.spk_embed_dim is not None:
        self.spk_embed_integration_type = spk_embed_integration_type

    # index 0 of the vocabulary is reserved for padding
    self.padding_idx = 0

    # positional encoding class used by encoder and decoder
    if self.use_scaled_pos_enc:
        pos_enc_class = ScaledPositionalEncoding
    else:
        pos_enc_class = PositionalEncoding

    # keyword arguments that encoder and decoder have in common
    shared_kwargs = dict(
        attention_dim=adim,
        attention_heads=aheads,
        pos_enc_class=pos_enc_class,
        positionwise_layer_type=positionwise_layer_type,
        positionwise_conv_kernel_size=positionwise_conv_kernel_size,
    )

    # ---- encoder ----
    encoder_input_layer = torch.nn.Embedding(
        num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
    )
    self.encoder = Encoder(
        idim=idim,
        linear_units=eunits,
        num_blocks=elayers,
        input_layer=encoder_input_layer,
        dropout_rate=transformer_enc_dropout_rate,
        positional_dropout_rate=transformer_enc_positional_dropout_rate,
        attention_dropout_rate=transformer_enc_attn_dropout_rate,
        normalize_before=encoder_normalize_before,
        concat_after=encoder_concat_after,
        **shared_kwargs,
    )

    # ---- global style tokens (input of the style encoder is the mel) ----
    if self.use_gst:
        self.gst = StyleEncoder(
            idim=odim,
            gst_tokens=gst_tokens,
            gst_token_dim=adim,
            gst_heads=gst_heads,
            conv_layers=gst_conv_layers,
            conv_chans_list=gst_conv_chans_list,
            conv_kernel_size=gst_conv_kernel_size,
            conv_stride=gst_conv_stride,
            gru_layers=gst_gru_layers,
            gru_units=gst_gru_units,
        )

    # ---- projection used to merge an external speaker embedding ----
    if self.spk_embed_dim is not None:
        if self.spk_embed_integration_type == "add":
            projection_in = self.spk_embed_dim
        else:
            projection_in = adim + self.spk_embed_dim
        self.projection = torch.nn.Linear(projection_in, adim)

    # ---- duration predictor ----
    self.duration_predictor = DurationPredictor(
        idim=adim,
        n_layers=duration_predictor_layers,
        n_chans=duration_predictor_chans,
        kernel_size=duration_predictor_kernel_size,
        dropout_rate=duration_predictor_dropout_rate,
    )

    # ---- length regulator ----
    self.length_regulator = LengthRegulator()

    # ---- decoder ----
    # NOTE: the decoder reuses the encoder class because the FastSpeech
    # decoder has the same structure as the encoder.
    self.decoder = Encoder(
        idim=0,
        linear_units=dunits,
        num_blocks=dlayers,
        input_layer=None,
        dropout_rate=transformer_dec_dropout_rate,
        positional_dropout_rate=transformer_dec_positional_dropout_rate,
        attention_dropout_rate=transformer_dec_attn_dropout_rate,
        normalize_before=decoder_normalize_before,
        concat_after=decoder_concat_after,
        **shared_kwargs,
    )

    # ---- final projection to the output features ----
    self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

    # ---- postnet (skipped entirely when it has no layers) ----
    if postnet_layers == 0:
        self.postnet = None
    else:
        self.postnet = Postnet(
            idim=idim,
            odim=odim,
            n_layers=postnet_layers,
            n_chans=postnet_chans,
            n_filts=postnet_filts,
            use_batch_norm=use_batch_norm,
            dropout_rate=postnet_dropout_rate,
        )

    # ---- parameter initialisation ----
    self._reset_parameters(
        init_type=init_type,
        init_enc_alpha=init_enc_alpha,
        init_dec_alpha=init_dec_alpha,
    )

    # ---- loss ----
    self.criterion = FastSpeechLoss(
        use_masking=use_masking, use_weighted_masking=use_weighted_masking
    )
class FastSpeech(AbsTTS):
    """FastSpeech module for end-to-end text-to-speech.

    This is a module of FastSpeech, feed-forward Transformer with duration
    predictor described in `FastSpeech: Fast, Robust and Controllable Text to
    Speech`_, which does not require any auto-regressive processing during
    inference, resulting in fast decoding compared with auto-regressive
    Transformer.

    This variant is additionally driven by an ``hparams`` object.  The
    attributes read from it in this class are ``is_multi_speakers``,
    ``n_speakers``, ``spk_embed_dim``, ``style_embed_integration_type``,
    ``is_partial_refine``, ``is_refine_style``, ``use_ssim_loss`` and
    ``loss_type``.

    .. _`FastSpeech: Fast, Robust and Controllable Text to Speech`:
        https://arxiv.org/pdf/1905.09263.pdf

    Args:
        idim (int): Dimension of the inputs.
        odim (int): Dimension of the outputs.
        adim (int, optional): Attention dimension of encoder and decoder.
        aheads (int, optional): Number of attention heads.
        elayers (int, optional): Number of encoder layers.
        eunits (int, optional): Number of encoder hidden units.
        dlayers (int, optional): Number of decoder layers.
        dunits (int, optional): Number of decoder hidden units.
        postnet_layers (int, optional): Number of postnet layers
            (0 disables the postnet).
        postnet_chans (int, optional): Number of postnet channels.
        postnet_filts (int, optional): Filter size of the postnet.
        use_scaled_pos_enc (bool, optional):
            Whether to use trainable scaled positional encoding.
        encoder_normalize_before (bool, optional):
            Whether to perform layer normalization before encoder block.
        decoder_normalize_before (bool, optional):
            Whether to perform layer normalization before decoder block.
        is_spk_layer_norm (bool, optional):
            Forwarded unchanged to the encoder and decoder modules.
        encoder_concat_after (bool, optional): Whether to concatenate attention
            layer's input and output in encoder.
        decoder_concat_after (bool, optional): Whether to concatenate attention
            layer's input and output in decoder.
        duration_predictor_layers (int, optional): Number of duration predictor
            layers.
        duration_predictor_chans (int, optional): Number of duration predictor
            channels.
        duration_predictor_kernel_size (int, optional):
            Kernel size of duration predictor.
        spk_embed_dim (int, optional): Number of speaker embedding dimensions.
        spk_embed_integration_type: How to integrate speaker embedding.
        use_gst (bool, optional): Whether to use global style token.
        gst_tokens (int, optional): The number of GST embeddings.
        gst_heads (int, optional): The number of heads in GST multihead attention.
        gst_conv_layers (int, optional): The number of conv layers in GST.
        gst_conv_chans_list (Sequence[int], optional):
            List of the number of channels of conv layers in GST.
        gst_conv_kernel_size (int, optional): Kernel size of conv layers in GST.
        gst_conv_stride (int, optional): Stride size of conv layers in GST.
        gst_gru_layers (int, optional): The number of GRU layers in GST.
        gst_gru_units (int, optional): The number of GRU units in GST.
        reduction_factor (int, optional): Reduction factor.
        transformer_enc_dropout_rate (float, optional):
            Dropout rate in encoder except attention & positional encoding.
        transformer_enc_positional_dropout_rate (float, optional):
            Dropout rate after encoder positional encoding.
        transformer_enc_attn_dropout_rate (float, optional):
            Dropout rate in encoder self-attention module.
        transformer_dec_dropout_rate (float, optional):
            Dropout rate in decoder except attention & positional encoding.
        transformer_dec_positional_dropout_rate (float, optional):
            Dropout rate after decoder positional encoding.
        transformer_dec_attn_dropout_rate (float, optional):
            Dropout rate in decoder self-attention module.
        hparams: Hyperparameter object (required although the default is
            ``None``); see the attribute list above.
        init_type (str, optional):
            How to initialize transformer parameters.
        init_enc_alpha (float, optional):
            Initial value of alpha in scaled pos encoding of the encoder.
        init_dec_alpha (float, optional):
            Initial value of alpha in scaled pos encoding of the decoder.
        use_masking (bool, optional):
            Whether to apply masking for padded part in loss calculation.
        use_weighted_masking (bool, optional):
            Whether to apply weighted masking in loss calculation.

    """

    def __init__(
        self,
        # network structure related
        idim: int,
        odim: int,
        adim: int = 384,
        aheads: int = 4,
        elayers: int = 6,
        eunits: int = 1536,
        dlayers: int = 6,
        dunits: int = 1536,
        postnet_layers: int = 5,
        postnet_chans: int = 512,
        postnet_filts: int = 5,
        positionwise_layer_type: str = "conv1d",
        positionwise_conv_kernel_size: int = 1,
        use_scaled_pos_enc: bool = True,
        use_batch_norm: bool = True,
        encoder_normalize_before: bool = False,
        decoder_normalize_before: bool = False,
        is_spk_layer_norm: bool = False,
        encoder_concat_after: bool = False,
        decoder_concat_after: bool = False,
        duration_predictor_layers: int = 2,
        duration_predictor_chans: int = 384,
        duration_predictor_kernel_size: int = 3,
        reduction_factor: int = 1,
        spk_embed_dim: int = None,
        spk_embed_integration_type: str = "add",
        use_gst: bool = False,
        gst_tokens: int = 10,
        gst_heads: int = 4,
        gst_conv_layers: int = 6,
        gst_conv_chans_list: Sequence[int] = (32, 32, 64, 64, 128, 128),
        gst_conv_kernel_size: int = 3,
        gst_conv_stride: int = 2,
        gst_gru_layers: int = 1,
        gst_gru_units: int = 128,
        # training related
        transformer_enc_dropout_rate: float = 0.1,
        transformer_enc_positional_dropout_rate: float = 0.1,
        transformer_enc_attn_dropout_rate: float = 0.1,
        transformer_dec_dropout_rate: float = 0.1,
        transformer_dec_positional_dropout_rate: float = 0.1,
        transformer_dec_attn_dropout_rate: float = 0.1,
        duration_predictor_dropout_rate: float = 0.1,
        postnet_dropout_rate: float = 0.5,
        hparams=None,
        init_type: str = "xavier_uniform",
        init_enc_alpha: float = 1.0,
        init_dec_alpha: float = 1.0,
        use_masking: bool = False,
        use_weighted_masking: bool = False,
    ):
        """Initialize FastSpeech module.

        Raises:
            ValueError: If ``hparams`` is not given.  The constructor and
                the forward pass read several attributes from it, so it
                cannot be omitted despite its ``None`` default.
        """
        assert check_argument_types()
        super().__init__()

        # Fail early with a readable message instead of an AttributeError
        # on the first ``hparams.<attr>`` access below.
        if hparams is None:
            raise ValueError("hparams must be provided to build FastSpeech.")

        # store hyperparameters
        self.idim = idim
        self.odim = odim
        self.reduction_factor = reduction_factor
        self.use_scaled_pos_enc = use_scaled_pos_enc
        self.use_gst = use_gst
        self.spk_embed_dim = spk_embed_dim
        self.hparams = hparams
        if self.hparams.is_multi_speakers:
            self.spk_embed_integration_type = spk_embed_integration_type

        # use idx 0 as padding idx
        self.padding_idx = 0

        # get positional encoding class
        pos_enc_class = (
            ScaledPositionalEncoding if self.use_scaled_pos_enc else PositionalEncoding
        )

        # define encoder
        encoder_input_layer = torch.nn.Embedding(
            num_embeddings=idim, embedding_dim=adim, padding_idx=self.padding_idx
        )

        if self.hparams.is_multi_speakers:
            # learned speaker table, initialised uniformly in [-val, val]
            self.speaker_embedding = torch.nn.Embedding(
                hparams.n_speakers, self.spk_embed_dim
            )
            std = sqrt(2.0 / (hparams.n_speakers + self.spk_embed_dim))
            val = sqrt(3.0) * std  # uniform bounds for std
            self.speaker_embedding.weight.data.uniform_(-val, val)
            # NOTE(review): assumed to belong to the multi-speaker branch
            # (it is only used there in ``_forward``) - confirm against the
            # original indentation.  It is sized from hparams.spk_embed_dim,
            # which presumably equals ``spk_embed_dim`` - verify.
            self.spkemb_projection = torch.nn.Linear(
                hparams.spk_embed_dim, hparams.spk_embed_dim
            )

        self.encoder = Encoder(
            idim=idim,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=eunits,
            num_blocks=elayers,
            input_layer=encoder_input_layer,
            dropout_rate=transformer_enc_dropout_rate,
            positional_dropout_rate=transformer_enc_positional_dropout_rate,
            attention_dropout_rate=transformer_enc_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=encoder_normalize_before,
            is_spk_layer_norm=is_spk_layer_norm,
            concat_after=encoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            hparams=hparams,
        )

        # define GST
        if self.use_gst:
            self.gst = StyleEncoder(
                idim=odim,  # the input is mel-spectrogram
                gst_tokens=gst_tokens,
                gst_token_dim=adim,
                gst_heads=gst_heads,
                conv_layers=gst_conv_layers,
                conv_chans_list=gst_conv_chans_list,
                conv_kernel_size=gst_conv_kernel_size,
                conv_stride=gst_conv_stride,
                gru_layers=gst_gru_layers,
                gru_units=gst_gru_units,
                hparams=hparams,
            )
            # NOTE(review): assumed to be created only when GST is enabled
            # (it is only used under ``use_gst`` in ``_forward``) - confirm
            # against the original indentation.
            if self.hparams.style_embed_integration_type == "concat":
                self.gst_projection = torch.nn.Linear(adim + adim, adim)

        # define additional projection for speaker embedding
        if self.hparams.is_multi_speakers:
            if self.spk_embed_integration_type == "add":
                self.projection = torch.nn.Linear(self.spk_embed_dim, adim)
            else:
                self.projection = torch.nn.Linear(adim + self.spk_embed_dim, adim)

        # define duration predictor
        self.duration_predictor = DurationPredictor(
            idim=adim,
            n_layers=duration_predictor_layers,
            n_chans=duration_predictor_chans,
            kernel_size=duration_predictor_kernel_size,
            dropout_rate=duration_predictor_dropout_rate,
            hparams=hparams,
        )

        # define length regulator
        self.length_regulator = LengthRegulator()

        # define decoder
        # NOTE: we use encoder as decoder
        # because fastspeech's decoder is the same as encoder
        self.decoder = Encoder(
            idim=0,
            attention_dim=adim,
            attention_heads=aheads,
            linear_units=dunits,
            num_blocks=dlayers,
            input_layer=None,
            dropout_rate=transformer_dec_dropout_rate,
            positional_dropout_rate=transformer_dec_positional_dropout_rate,
            attention_dropout_rate=transformer_dec_attn_dropout_rate,
            pos_enc_class=pos_enc_class,
            normalize_before=decoder_normalize_before,
            is_spk_layer_norm=is_spk_layer_norm,
            concat_after=decoder_concat_after,
            positionwise_layer_type=positionwise_layer_type,
            positionwise_conv_kernel_size=positionwise_conv_kernel_size,
            hparams=hparams,
        )

        # define final projection
        self.feat_out = torch.nn.Linear(adim, odim * reduction_factor)

        # define postnet
        self.postnet = (
            None
            if postnet_layers == 0
            else Postnet(
                idim=idim,
                odim=odim,
                n_layers=postnet_layers,
                n_chans=postnet_chans,
                n_filts=postnet_filts,
                use_batch_norm=use_batch_norm,
                dropout_rate=postnet_dropout_rate,
            )
        )

        # initialize parameters
        self._reset_parameters(
            init_type=init_type,
            init_enc_alpha=init_enc_alpha,
            init_dec_alpha=init_dec_alpha,
        )

        # define criterions
        self.criterion = FastSpeechLoss(
            use_masking=use_masking, use_weighted_masking=use_weighted_masking
        )

    def _forward(
        self,
        xs: torch.Tensor,
        ilens: torch.Tensor,
        ys: torch.Tensor = None,
        olens: torch.Tensor = None,
        ds: torch.Tensor = None,
        spk_ids: torch.Tensor = None,
        style_ids: torch.Tensor = None,
        utt_mels: list = None,
        is_inference: bool = False,
        alpha: float = 1.0,
    ) -> Sequence[torch.Tensor]:
        """Run encoder, style / speaker integration, length regulation and decoder.

        Args:
            xs: Batch of padded input ids.
            ilens: Lengths of the inputs.
            ys: Target features, fed to the style encoder when GST is
                enabled and the partial-refine path is not taken.
            olens: Target lengths; used to build the decoder mask during
                training.
            ds: Ground-truth durations, used by the length regulator when
                ``is_inference`` is false.
            spk_ids: Speaker ids looked up in the speaker embedding table
                when ``hparams.is_multi_speakers`` is set.
            style_ids: Accepted for interface compatibility; not used here.
            utt_mels: Reference mels for the partial-refine style path
                (``hparams.is_partial_refine and hparams.is_refine_style``).
            is_inference: Use predicted durations instead of ``ds``.
            alpha: Speed control value handed to the length regulator at
                inference time.

        Returns:
            Tuple ``(before_outs, after_outs, d_outs)``: outputs before and
            after the postnet (identical when there is no postnet) and the
            duration predictor output.
        """
        # Speaker embedding.  It stays ``None`` for single-speaker models;
        # previously the name was left unbound in that case and the encoder
        # call below raised NameError.
        # NOTE(review): assumes encoder, decoder, GST and duration predictor
        # accept ``spembs_=None`` - confirm.
        spembs_ = None
        if self.hparams.is_multi_speakers:
            spembs_ = self.spkemb_projection(self.speaker_embedding(spk_ids))

        # forward encoder
        x_masks = self._source_mask(ilens)
        hs, _ = self.encoder(xs, x_masks, spembs_=spembs_)  # (B, Tmax, adim)

        # integrate speaker embedding
        if self.hparams.is_multi_speakers:
            hs = self._integrate_with_spk_embed(hs, spembs_)

        if self.use_gst:
            if self.hparams.is_partial_refine and self.hparams.is_refine_style:
                # one style embedding per reference mel, then let the GST
                # module pick per-position styles from that bank
                per_mel_embs = [
                    self.gst(to_gpu(mel), spembs_=spembs_)  # (1, gst_token_dim)
                    for mel in utt_mels
                ]
                style_bank = torch.cat(per_mel_embs, dim=0)  # (N, gst_token_dim)
                style_tokens = style_bank.unsqueeze(0).expand(
                    hs.size(0), -1, -1
                )  # (B, N, gst_token_dim)
                style_embs = self.gst.choosestl(hs, style_tokens)
            else:
                style_embs = self.gst(ys, spembs_=spembs_)

            # integrate with GST
            if self.hparams.style_embed_integration_type == "concat":
                expanded = style_embs.unsqueeze(1).expand(-1, hs.size(1), -1)
                hs = self.gst_projection(torch.cat([hs, expanded], dim=-1))
            else:
                hs = hs + style_embs.unsqueeze(1)

        # forward duration predictor and length regulator
        d_masks = make_pad_mask(ilens).to(xs.device)
        if is_inference:
            d_outs = self.duration_predictor.inference(
                hs, d_masks, spembs_=spembs_
            )  # (B, Tmax)
            hs = self.length_regulator(hs, d_outs, ilens, alpha)  # (B, Lmax, adim)
        else:
            d_outs = self.duration_predictor(hs, d_masks, spembs_=spembs_)  # (B, Tmax)
            hs = self.length_regulator(hs, ds, ilens)  # (B, Lmax, adim)

        # forward decoder
        if olens is not None and not is_inference:
            if self.reduction_factor > 1:
                olens_in = olens.new(
                    [olen // self.reduction_factor for olen in olens]
                )
            else:
                olens_in = olens
            h_masks = self._source_mask(olens_in)
        else:
            h_masks = None
        zs, _ = self.decoder(hs, h_masks, spembs_=spembs_)  # (B, Lmax, adim)
        before_outs = self.feat_out(zs).view(
            zs.size(0), -1, self.odim
        )  # (B, Lmax, odim)

        # postnet -> (B, Lmax//r * r, odim)
        if self.postnet is None:
            after_outs = before_outs
        else:
            after_outs = before_outs + self.postnet(
                before_outs.transpose(1, 2)
            ).transpose(1, 2)

        return before_outs, after_outs, d_outs

    def forward(
        self,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        durations: torch.Tensor,
        durations_lengths: torch.Tensor,
        spk_ids: torch.Tensor = None,
        style_ids: torch.Tensor = None,
        utt_mels: list = None,
    ) -> Tuple[
        torch.Tensor,
        Dict[str, torch.Tensor],
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
        torch.Tensor,
    ]:
        """Calculate forward propagation.

        Args:
            text (LongTensor): Batch of padded character ids (B, Tmax).
            text_lengths (LongTensor): Batch of lengths of each input (B,).
            speech (Tensor): Batch of padded target features (B, Lmax, odim).
            speech_lengths (LongTensor): Batch of the lengths of each target (B,).
            durations (LongTensor): Batch of padded durations (B, Tmax).
            durations_lengths (LongTensor): Batch of duration lengths (B,).
            spk_ids (Tensor, optional): Speaker ids, forwarded to ``_forward``.
            style_ids (Tensor, optional): Forwarded to ``_forward``.
            utt_mels (list, optional): Reference mels, forwarded to ``_forward``.

        Returns:
            Tensor: Loss scalar value.
            Dict: Statistics to be monitored.
            Tensor: Weight value.
            Tensor: Outputs after the postnet (``None`` without a postnet).
            Tensor: Target features, trimmed to a multiple of the reduction
                factor when it is greater than 1.
            Tensor: Target lengths matching the returned targets.

        Raises:
            ValueError: If ``hparams.loss_type`` is not one of ``"L1"``,
                ``"L2"`` or ``"L1_L2"``.

        """
        text = text[:, : text_lengths.max()]  # for data-parallel
        speech = speech[:, : speech_lengths.max()]  # for data-parallel
        durations = durations[:, : durations_lengths.max()]  # for data-parallel

        batch_size = text.size(0)

        xs = text
        ilens = text_lengths
        ys, ds = speech, durations
        olens = speech_lengths

        # forward propagation
        before_outs, after_outs, d_outs = self._forward(
            xs,
            ilens,
            ys,
            olens,
            ds,
            spk_ids=spk_ids,
            style_ids=style_ids,
            utt_mels=utt_mels,
            is_inference=False,
        )

        # trim the ground truth so that its length is a multiple of r
        if self.reduction_factor > 1:
            olens = olens.new(
                [olen - olen % self.reduction_factor for olen in olens]
            )
            max_olen = max(olens)
            ys = ys[:, :max_olen]

        # calculate loss
        if self.postnet is None:
            after_outs = None
        l1_loss, duration_loss, l2_loss = self.criterion(
            after_outs, before_outs, d_outs, ys, ds, ilens, olens
        )

        use_ssim = self.hparams.use_ssim_loss
        if use_ssim:
            # optional dependency, imported only when the SSIM term is used
            import pytorch_ssim

            ssim_loss = pytorch_ssim.SSIM()
            ssim_out = 3.0 * (2.0 - ssim_loss(before_outs, after_outs, ys, olens))
        ssim_term = ssim_out if use_ssim else 0

        loss_type = self.hparams.loss_type
        if loss_type == "L1":
            loss = l1_loss + duration_loss + ssim_term
        elif loss_type == "L2":
            loss = l2_loss + duration_loss + ssim_term
        elif loss_type == "L1_L2":
            loss = l1_loss + duration_loss + ssim_term + l2_loss
        else:
            # previously ``loss`` was simply left unbound in this case
            raise ValueError(f"unknown --loss-type {loss_type}")

        stats = dict(
            L1=l1_loss.item() if loss_type == "L1" else 0,
            L2=l2_loss.item() if loss_type == "L2" else 0,
            L1_L2=l1_loss.item() + l2_loss.item() if loss_type == "L1_L2" else 0,
            duration_loss=duration_loss.item(),
            loss=loss.item(),
            ssim_loss=ssim_out.item() if use_ssim else 0,
        )

        # report extra information
        if self.use_scaled_pos_enc:
            stats.update(
                encoder_alpha=self.encoder.embed[-1].alpha.data.item(),
                decoder_alpha=self.decoder.embed[-1].alpha.data.item(),
            )

        loss, stats, weight = force_gatherable(
            (loss, stats, batch_size), loss.device
        )
        return loss, stats, weight, after_outs, ys, olens

    def inference(
        self,
        text: torch.Tensor,
        speech: torch.Tensor = None,
        spk_ids: torch.Tensor = None,
        durations: torch.Tensor = None,
        alpha: float = 1.0,
        use_teacher_forcing: bool = False,
        utt_mels: list = None,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Generate the sequence of features given the sequences of characters.

        Args:
            text (LongTensor): Input sequence of characters (T,).
            speech (Tensor, optional): Accepted for interface compatibility;
                it is currently not used (no target features are passed on).
            spk_ids (Tensor, optional): Speaker ids, passed to ``_forward``
                unchanged (no batch axis is added here).
            durations (LongTensor, optional): Groundtruth of duration;
                required when ``use_teacher_forcing`` is true.
            alpha (float, optional): Alpha to control the speed.
            use_teacher_forcing (bool, optional): Whether to use teacher forcing.
                If true, groundtruth of duration will be used.
            utt_mels (list, optional): Reference mels, forwarded to
                ``_forward`` on the non-teacher-forcing path.

        Returns:
            Tensor: Output sequence of features (L, odim).
            None: Dummy for compatibility.
            None: Dummy for compatibility.

        Raises:
            ValueError: If teacher forcing is requested without durations.

        """
        # setup batch axis
        ilens = torch.tensor([text.shape[0]], dtype=torch.long, device=text.device)
        xs, ys = text.unsqueeze(0), None

        if use_teacher_forcing:
            if durations is None:
                raise ValueError(
                    "durations must be given when use_teacher_forcing is True."
                )
            # use groundtruth of duration
            ds = durations.unsqueeze(0)
            _, outs, *_ = self._forward(
                xs,
                ilens,
                ys,
                ds=ds,
                spk_ids=spk_ids,
            )  # (1, L, odim)
        else:
            # inference with predicted durations
            _, outs, _ = self._forward(
                xs,
                ilens,
                ys,
                spk_ids=spk_ids,
                is_inference=True,
                alpha=alpha,
                utt_mels=utt_mels,
            )  # (1, L, odim)

        return outs[0], None, None
""" if self.spk_embed_integration_type == "add": # apply projection and then add to hidden states spembs = self.projection(F.normalize(spembs)) hs = hs + spembs.unsqueeze(1) elif self.spk_embed_integration_type == "concat": # concat hidden states with spk embeds and then apply projection spembs = F.normalize(spembs).unsqueeze(1).expand( -1, hs.size(1), -1) hs = self.projection(torch.cat([hs, spembs], dim=-1)) else: raise NotImplementedError("support only add or concat.") return hs def _source_mask(self, ilens: torch.Tensor) -> torch.Tensor: """Make masks for self-attention. Args: ilens (LongTensor): Batch of lengths (B,). Returns: Tensor: Mask tensor for self-attention. dtype=torch.uint8 in PyTorch 1.2- dtype=torch.bool in PyTorch 1.2+ (including 1.2) Examples: >>> ilens = [5, 3] >>> self._source_mask(ilens) tensor([[[1, 1, 1, 1, 1], [1, 1, 1, 0, 0]]], dtype=torch.uint8) """ x_masks = make_non_pad_mask(ilens).to(next(self.parameters()).device) return x_masks.unsqueeze(-2) def _reset_parameters(self, init_type: str, init_enc_alpha: float, init_dec_alpha: float): # initialize parameters if init_type != "pytorch": initialize(self, init_type) # initialize alpha in scaled positional encoding if self.use_scaled_pos_enc: self.encoder.embed[-1].alpha.data = torch.tensor(init_enc_alpha) self.decoder.embed[-1].alpha.data = torch.tensor(init_dec_alpha) @staticmethod def _parse_batch(batch, hparams, utt_mels=None): text_padded, input_lengths, mel_padded, output_lengths, duration_padded, duration_lengths = batch[: 6] text_padded = to_gpu(text_padded).long() input_lengths = to_gpu(input_lengths).long() mel_padded = to_gpu(mel_padded).float() output_lengths = to_gpu(output_lengths).long() duration_padded = to_gpu(duration_padded).long() duration_lengths = to_gpu(duration_lengths).long() idx = 6 speaker_ids = None style_ids = None utt_mels = utt_mels if hparams.is_multi_speakers: speaker_ids = batch[idx] speaker_ids = to_gpu(speaker_ids).long() idx += 1 if hparams.is_multi_styles: 
style_ids = batch[idx] style_ids = to_gpu(style_ids).long() #(B,) idx += 1 return (text_padded, input_lengths, mel_padded, output_lengths, duration_padded, duration_lengths, speaker_ids, style_ids, utt_mels)