def test_hifigan_generator_and_discriminator_and_loss(
    dict_g, dict_d, dict_loss, average, include
):
    """Smoke-test one generator step and one discriminator step for HiFiGAN.

    The test builds the generator, the multi-scale multi-period discriminator
    and the losses from the given argument overrides, then runs a single
    optimizer update for each network. It passes as long as no exception is
    raised while computing and back-propagating both losses.

    Args:
        dict_g: Overrides forwarded to ``make_hifigan_generator_args``.
        dict_d: Overrides forwarded to
            ``make_hifigan_multi_scale_multi_period_discriminator_args``.
        dict_loss: Overrides forwarded to ``make_mel_loss_args``.
        average: Value used for every ``average_by_*`` loss option.
        include: Value used for ``include_final_outputs`` of the feature
            matching loss.

    """
    n_batch, n_samples = 2, 128

    gen_kwargs = make_hifigan_generator_args(**dict_g)
    dis_kwargs = make_hifigan_multi_scale_multi_period_discriminator_args(**dict_d)
    mel_kwargs = make_mel_loss_args(**dict_loss)

    # Random inputs: target waveform, conditioning features and (optionally)
    # a global conditioning vector. The creation order is kept as
    # waveform -> features -> global vector so the RNG stream is unchanged.
    wav_real = torch.randn(n_batch, 1, n_samples)
    n_frames = n_samples // np.prod(gen_kwargs["upsample_scales"])
    feats = torch.randn(n_batch, gen_kwargs["in_channels"], n_frames)
    global_cond = (
        torch.randn(n_batch, gen_kwargs["global_channels"], 1)
        if gen_kwargs.get("global_channels") is not None
        else None
    )

    # Networks
    generator = HiFiGANGenerator(**gen_kwargs)
    discriminator = HiFiGANMultiScaleMultiPeriodDiscriminator(**dis_kwargs)

    # Criteria
    mel_criterion = MelSpectrogramLoss(**mel_kwargs)
    fm_criterion = FeatureMatchLoss(
        average_by_layers=average,
        average_by_discriminators=average,
        include_final_outputs=include,
    )
    adv_criterion_g = GeneratorAdversarialLoss(average_by_discriminators=average)
    adv_criterion_d = DiscriminatorAdversarialLoss(average_by_discriminators=average)

    # Optimizers
    opt_g = torch.optim.AdamW(generator.parameters())
    opt_d = torch.optim.AdamW(discriminator.parameters())

    # --- generator step: the generator must be trainable ---
    wav_fake = generator(feats, g=global_cond)
    outs_fake = discriminator(wav_fake)
    mel_term = mel_criterion(wav_fake, wav_real)
    adv_term = adv_criterion_g(outs_fake)
    with torch.no_grad():
        # Real-side outputs are only targets for feature matching here.
        outs_real = discriminator(wav_real)
    fm_term = fm_criterion(outs_fake, outs_real)
    total_g = adv_term + mel_term + fm_term
    opt_g.zero_grad()
    total_g.backward()
    opt_g.step()

    # --- discriminator step: the discriminator must be trainable ---
    outs_real = discriminator(wav_real)
    # Detach so the discriminator loss does not back-propagate into the
    # generator graph.
    outs_fake = discriminator(wav_fake.detach())
    loss_real, loss_fake = adv_criterion_d(outs_fake, outs_real)
    total_d = loss_real + loss_fake
    opt_d.zero_grad()
    total_d.backward()
    opt_d.step()
def __init__(
    self,
    # generator (text2mel + vocoder) related
    idim: int,
    odim: int,
    segment_size: int = 32,
    sampling_rate: int = 22050,
    text2mel_type: str = "fastspeech2",
    text2mel_params: Dict[str, Any] = {
        "adim": 384,
        "aheads": 2,
        "elayers": 4,
        "eunits": 1536,
        "dlayers": 4,
        "dunits": 1536,
        "postnet_layers": 5,
        "postnet_chans": 512,
        "postnet_filts": 5,
        "postnet_dropout_rate": 0.5,
        "positionwise_layer_type": "conv1d",
        "positionwise_conv_kernel_size": 1,
        "use_scaled_pos_enc": True,
        "use_batch_norm": True,
        "encoder_normalize_before": True,
        "decoder_normalize_before": True,
        "encoder_concat_after": False,
        "decoder_concat_after": False,
        "reduction_factor": 1,
        "encoder_type": "conformer",
        "decoder_type": "conformer",
        "transformer_enc_dropout_rate": 0.1,
        "transformer_enc_positional_dropout_rate": 0.1,
        "transformer_enc_attn_dropout_rate": 0.1,
        "transformer_dec_dropout_rate": 0.1,
        "transformer_dec_positional_dropout_rate": 0.1,
        "transformer_dec_attn_dropout_rate": 0.1,
        "conformer_rel_pos_type": "latest",
        "conformer_pos_enc_layer_type": "rel_pos",
        "conformer_self_attn_layer_type": "rel_selfattn",
        "conformer_activation_type": "swish",
        "use_macaron_style_in_conformer": True,
        "use_cnn_in_conformer": True,
        "zero_triu": False,
        "conformer_enc_kernel_size": 7,
        "conformer_dec_kernel_size": 31,
        "duration_predictor_layers": 2,
        "duration_predictor_chans": 384,
        "duration_predictor_kernel_size": 3,
        "duration_predictor_dropout_rate": 0.1,
        "energy_predictor_layers": 2,
        "energy_predictor_chans": 384,
        "energy_predictor_kernel_size": 3,
        "energy_predictor_dropout": 0.5,
        "energy_embed_kernel_size": 1,
        "energy_embed_dropout": 0.5,
        "stop_gradient_from_energy_predictor": False,
        "pitch_predictor_layers": 5,
        "pitch_predictor_chans": 384,
        "pitch_predictor_kernel_size": 5,
        "pitch_predictor_dropout": 0.5,
        "pitch_embed_kernel_size": 1,
        "pitch_embed_dropout": 0.5,
        "stop_gradient_from_pitch_predictor": True,
        "spks": -1,
        "langs": -1,
        "spk_embed_dim": None,
        "spk_embed_integration_type": "add",
        "use_gst": False,
        "gst_tokens": 10,
        "gst_heads": 4,
        "gst_conv_layers": 6,
        "gst_conv_chans_list": [32, 32, 64, 64, 128, 128],
        "gst_conv_kernel_size": 3,
        "gst_conv_stride": 2,
        "gst_gru_layers": 1,
        "gst_gru_units": 128,
        "init_type": "xavier_uniform",
        "init_enc_alpha": 1.0,
        "init_dec_alpha": 1.0,
        "use_masking": False,
        "use_weighted_masking": False,
    },
    vocoder_type: str = "hifigan_generator",
    vocoder_params: Dict[str, Any] = {
        "out_channels": 1,
        "channels": 512,
        "global_channels": -1,
        "kernel_size": 7,
        "upsample_scales": [8, 8, 2, 2],
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "use_additional_convs": True,
        "bias": True,
        "nonlinear_activation": "LeakyReLU",
        "nonlinear_activation_params": {"negative_slope": 0.1},
        "use_weight_norm": True,
    },
    use_pqmf: bool = False,
    pqmf_params: Dict[str, Any] = {
        "subbands": 4,
        "taps": 62,
        "cutoff_ratio": 0.142,
        "beta": 9.0,
    },
    # discriminator related
    discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
    discriminator_params: Dict[str, Any] = {
        "scales": 1,
        "scale_downsample_pooling": "AvgPool1d",
        "scale_downsample_pooling_params": {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        "scale_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
        "follow_official_norm": False,
        "periods": [2, 3, 5, 7, 11],
        "period_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    },
    # loss related
    generator_adv_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "loss_type": "mse",
    },
    discriminator_adv_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "loss_type": "mse",
    },
    use_feat_match_loss: bool = True,
    feat_match_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "average_by_layers": False,
        "include_final_outputs": True,
    },
    use_mel_loss: bool = True,
    mel_loss_params: Dict[str, Any] = {
        "fs": 22050,
        "n_fft": 1024,
        "hop_length": 256,
        "win_length": None,
        "window": "hann",
        "n_mels": 80,
        "fmin": 0,
        "fmax": None,
        "log_base": None,
    },
    lambda_text2mel: float = 1.0,
    lambda_adv: float = 1.0,
    lambda_feat_match: float = 2.0,
    lambda_mel: float = 45.0,
    cache_generator_outputs: bool = False,
):
    """Initialize JointText2Wav module.

    The parameter dicts passed in (including the shared default dicts) are
    never modified: ``idim`` / ``odim`` are injected into private copies.

    Args:
        idim (int): Input vocabulary size.
        odim (int): Acoustic feature dimension. The actual output channels will
            be 1 since the model is the end-to-end text-to-wave model but for the
            compatibility odim is used to indicate the acoustic feature dimension.
        segment_size (int): Segment size for random windowed inputs.
        sampling_rate (int): Sampling rate, not used for the training but it will
            be referred in saving waveform during the inference.
        text2mel_type (str): The text2mel model type.
        text2mel_params (Dict[str, Any]): Parameter dict for text2mel model.
        vocoder_type (str): The vocoder model type.
        vocoder_params (Dict[str, Any]): Parameter dict for vocoder model.
        use_pqmf (bool): Whether to use PQMF for multi-band vocoder.
        pqmf_params (Dict[str, Any]): Parameter dict for PQMF module.
        discriminator_type (str): Discriminator type.
        discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
        generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
            adversarial loss.
        discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
            discriminator adversarial loss.
        use_feat_match_loss (bool): Whether to use feat match loss.
        feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
        use_mel_loss (bool): Whether to use mel loss.
        mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
        lambda_text2mel (float): Loss scaling coefficient for text2mel model loss.
        lambda_adv (float): Loss scaling coefficient for adversarial loss.
        lambda_feat_match (float): Loss scaling coefficient for feat match loss.
        lambda_mel (float): Loss scaling coefficient for mel loss.
        cache_generator_outputs (bool): Whether to cache generator outputs.

    Raises:
        KeyError: If ``text2mel_type``, ``vocoder_type`` or
            ``discriminator_type`` is not a registered type.

    """
    assert check_argument_types()
    super().__init__()
    self.segment_size = segment_size
    self.use_pqmf = use_pqmf

    # define modules
    self.generator = torch.nn.ModuleDict()
    text2mel_class = AVAILABLE_TEXT2MEL[text2mel_type]
    # NOTE: Build a new dict instead of calling ``.update()`` on the argument.
    # The default value is a single dict object shared by every call, so an
    # in-place update would leak idim / odim into later instantiations and
    # would also silently modify a dict owned by the caller.
    text2mel_params = dict(text2mel_params, idim=idim, odim=odim)
    self.generator["text2mel"] = text2mel_class(
        **text2mel_params,
    )
    vocoder_class = AVAILABLE_VOCODER[vocoder_type]
    # Same reasoning as above: work on a private shallow copy so that the
    # type-specific key added for one vocoder type cannot leak into another.
    vocoder_params = dict(vocoder_params)
    if vocoder_type == "hifigan_generator":
        vocoder_params.update(in_channels=odim)
    elif vocoder_type == "parallel_wavegan_generator":
        vocoder_params.update(aux_channels=odim)
    self.generator["vocoder"] = vocoder_class(
        **vocoder_params,
    )
    if self.use_pqmf:
        self.pqmf = PQMF(**pqmf_params)
    discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
    self.discriminator = discriminator_class(
        **discriminator_params,
    )
    self.generator_adv_loss = GeneratorAdversarialLoss(
        **generator_adv_loss_params,
    )
    self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
        **discriminator_adv_loss_params,
    )
    self.use_feat_match_loss = use_feat_match_loss
    if self.use_feat_match_loss:
        self.feat_match_loss = FeatureMatchLoss(
            **feat_match_loss_params,
        )
    self.use_mel_loss = use_mel_loss
    if self.use_mel_loss:
        self.mel_loss = MelSpectrogramLoss(
            **mel_loss_params,
        )

    # coefficients
    self.lambda_text2mel = lambda_text2mel
    self.lambda_adv = lambda_adv
    # The optional coefficients are only stored when the matching loss is on.
    if self.use_feat_match_loss:
        self.lambda_feat_match = lambda_feat_match
    if self.use_mel_loss:
        self.lambda_mel = lambda_mel

    # cache
    self.cache_generator_outputs = cache_generator_outputs
    self._cache = None

    # store sampling rate for saving wav file
    # (not used for the training)
    self.fs = sampling_rate
def __init__(
    self,
    # generator related
    idim: int,
    odim: int,
    sampling_rate: int = 22050,
    generator_type: str = "vits_generator",
    generator_params: Dict[str, Any] = {
        "hidden_channels": 192,
        "spks": None,
        "langs": None,
        "spk_embed_dim": None,
        "global_channels": -1,
        "segment_size": 32,
        "text_encoder_attention_heads": 2,
        "text_encoder_ffn_expand": 4,
        "text_encoder_blocks": 6,
        "text_encoder_positionwise_layer_type": "conv1d",
        "text_encoder_positionwise_conv_kernel_size": 1,
        "text_encoder_positional_encoding_layer_type": "rel_pos",
        "text_encoder_self_attention_layer_type": "rel_selfattn",
        "text_encoder_activation_type": "swish",
        "text_encoder_normalize_before": True,
        "text_encoder_dropout_rate": 0.1,
        "text_encoder_positional_dropout_rate": 0.0,
        "text_encoder_attention_dropout_rate": 0.0,
        "text_encoder_conformer_kernel_size": 7,
        "use_macaron_style_in_text_encoder": True,
        "use_conformer_conv_in_text_encoder": True,
        "decoder_kernel_size": 7,
        "decoder_channels": 512,
        "decoder_upsample_scales": [8, 8, 2, 2],
        "decoder_upsample_kernel_sizes": [16, 16, 4, 4],
        "decoder_resblock_kernel_sizes": [3, 7, 11],
        "decoder_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "use_weight_norm_in_decoder": True,
        "posterior_encoder_kernel_size": 5,
        "posterior_encoder_layers": 16,
        "posterior_encoder_stacks": 1,
        "posterior_encoder_base_dilation": 1,
        "posterior_encoder_dropout_rate": 0.0,
        "use_weight_norm_in_posterior_encoder": True,
        "flow_flows": 4,
        "flow_kernel_size": 5,
        "flow_base_dilation": 1,
        "flow_layers": 4,
        "flow_dropout_rate": 0.0,
        "use_weight_norm_in_flow": True,
        "use_only_mean_in_flow": True,
        "stochastic_duration_predictor_kernel_size": 3,
        "stochastic_duration_predictor_dropout_rate": 0.5,
        "stochastic_duration_predictor_flows": 4,
        "stochastic_duration_predictor_dds_conv_layers": 3,
    },
    # discriminator related
    discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
    discriminator_params: Dict[str, Any] = {
        "scales": 1,
        "scale_downsample_pooling": "AvgPool1d",
        "scale_downsample_pooling_params": {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        "scale_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
        "follow_official_norm": False,
        "periods": [2, 3, 5, 7, 11],
        "period_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    },
    # loss related
    generator_adv_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "loss_type": "mse",
    },
    discriminator_adv_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "loss_type": "mse",
    },
    feat_match_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "average_by_layers": False,
        "include_final_outputs": True,
    },
    mel_loss_params: Dict[str, Any] = {
        "fs": 22050,
        "n_fft": 1024,
        "hop_length": 256,
        "win_length": None,
        "window": "hann",
        "n_mels": 80,
        "fmin": 0,
        "fmax": None,
        "log_base": None,
    },
    lambda_adv: float = 1.0,
    lambda_mel: float = 45.0,
    lambda_feat_match: float = 2.0,
    lambda_dur: float = 1.0,
    lambda_kl: float = 1.0,
    cache_generator_outputs: bool = True,
):
    """Initialize VITS module.

    The ``generator_params`` dict passed in (including the shared default
    dict) is never modified: the vocabulary size and the acoustic feature
    dimension are injected into a private copy.

    Args:
        idim (int): Input vocabulary size.
        odim (int): Acoustic feature dimension. The actual output channels will
            be 1 since VITS is the end-to-end text-to-wave model but for the
            compatibility odim is used to indicate the acoustic feature dimension.
        sampling_rate (int): Sampling rate, not used for the training but it will
            be referred in saving waveform during the inference.
        generator_type (str): Generator type.
        generator_params (Dict[str, Any]): Parameter dict for generator.
        discriminator_type (str): Discriminator type.
        discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
        generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
            adversarial loss.
        discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
            discriminator adversarial loss.
        feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
        mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
        lambda_adv (float): Loss scaling coefficient for adversarial loss.
        lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
        lambda_feat_match (float): Loss scaling coefficient for feat match loss.
        lambda_dur (float): Loss scaling coefficient for duration loss.
        lambda_kl (float): Loss scaling coefficient for KL divergence loss.
        cache_generator_outputs (bool): Whether to cache generator outputs.

    Raises:
        KeyError: If ``generator_type`` or ``discriminator_type`` is not a
            registered type.

    """
    assert check_argument_types()
    super().__init__()

    # define modules
    generator_class = AVAILABLE_GENERATERS[generator_type]
    # NOTE: Work on a private shallow copy. The default value is a single
    # dict object shared by every call, so an in-place ``.update()`` would
    # leak ``vocabs`` / ``aux_channels`` into later instantiations (also of
    # other generator types) and would modify a dict owned by the caller.
    generator_params = dict(generator_params)
    if generator_type == "vits_generator":
        # NOTE(kan-bayashi): Update parameters for the compatibility.
        #   The idim and odim is automatically decided from input data,
        #   where idim represents #vocabularies and odim represents
        #   the input acoustic feature dimension.
        generator_params.update(vocabs=idim, aux_channels=odim)
    self.generator = generator_class(
        **generator_params,
    )
    discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
    self.discriminator = discriminator_class(
        **discriminator_params,
    )
    self.generator_adv_loss = GeneratorAdversarialLoss(
        **generator_adv_loss_params,
    )
    self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
        **discriminator_adv_loss_params,
    )
    self.feat_match_loss = FeatureMatchLoss(
        **feat_match_loss_params,
    )
    self.mel_loss = MelSpectrogramLoss(
        **mel_loss_params,
    )
    self.kl_loss = KLDivergenceLoss()

    # coefficients
    self.lambda_adv = lambda_adv
    self.lambda_mel = lambda_mel
    self.lambda_kl = lambda_kl
    self.lambda_feat_match = lambda_feat_match
    self.lambda_dur = lambda_dur

    # cache
    self.cache_generator_outputs = cache_generator_outputs
    self._cache = None

    # store sampling rate for saving wav file
    # (not used for the training)
    self.fs = sampling_rate

    # store parameters for test compatibility
    self.spks = self.generator.spks
    self.langs = self.generator.langs
    self.spk_embed_dim = self.generator.spk_embed_dim
def __init__(
    self,
    # generator related
    idim: int,
    odim: int,
    sampling_rate: int = 22050,
    generator_type: str = "jets_generator",
    generator_params: Dict[str, Any] = {
        "adim": 256,
        "aheads": 2,
        "elayers": 4,
        "eunits": 1024,
        "dlayers": 4,
        "dunits": 1024,
        "positionwise_layer_type": "conv1d",
        "positionwise_conv_kernel_size": 1,
        "use_scaled_pos_enc": True,
        "use_batch_norm": True,
        "encoder_normalize_before": True,
        "decoder_normalize_before": True,
        "encoder_concat_after": False,
        "decoder_concat_after": False,
        "reduction_factor": 1,
        "encoder_type": "transformer",
        "decoder_type": "transformer",
        "transformer_enc_dropout_rate": 0.1,
        "transformer_enc_positional_dropout_rate": 0.1,
        "transformer_enc_attn_dropout_rate": 0.1,
        "transformer_dec_dropout_rate": 0.1,
        "transformer_dec_positional_dropout_rate": 0.1,
        "transformer_dec_attn_dropout_rate": 0.1,
        "conformer_rel_pos_type": "latest",
        "conformer_pos_enc_layer_type": "rel_pos",
        "conformer_self_attn_layer_type": "rel_selfattn",
        "conformer_activation_type": "swish",
        "use_macaron_style_in_conformer": True,
        "use_cnn_in_conformer": True,
        "zero_triu": False,
        "conformer_enc_kernel_size": 7,
        "conformer_dec_kernel_size": 31,
        "duration_predictor_layers": 2,
        "duration_predictor_chans": 384,
        "duration_predictor_kernel_size": 3,
        "duration_predictor_dropout_rate": 0.1,
        "energy_predictor_layers": 2,
        "energy_predictor_chans": 384,
        "energy_predictor_kernel_size": 3,
        "energy_predictor_dropout": 0.5,
        "energy_embed_kernel_size": 1,
        "energy_embed_dropout": 0.5,
        "stop_gradient_from_energy_predictor": False,
        "pitch_predictor_layers": 5,
        "pitch_predictor_chans": 384,
        "pitch_predictor_kernel_size": 5,
        "pitch_predictor_dropout": 0.5,
        "pitch_embed_kernel_size": 1,
        "pitch_embed_dropout": 0.5,
        "stop_gradient_from_pitch_predictor": True,
        "generator_out_channels": 1,
        "generator_channels": 512,
        "generator_global_channels": -1,
        "generator_kernel_size": 7,
        "generator_upsample_scales": [8, 8, 2, 2],
        "generator_upsample_kernel_sizes": [16, 16, 4, 4],
        "generator_resblock_kernel_sizes": [3, 7, 11],
        "generator_resblock_dilations": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "generator_use_additional_convs": True,
        "generator_bias": True,
        "generator_nonlinear_activation": "LeakyReLU",
        "generator_nonlinear_activation_params": {"negative_slope": 0.1},
        "generator_use_weight_norm": True,
        "segment_size": 64,
        "spks": -1,
        "langs": -1,
        "spk_embed_dim": None,
        "spk_embed_integration_type": "add",
        "use_gst": False,
        "gst_tokens": 10,
        "gst_heads": 4,
        "gst_conv_layers": 6,
        "gst_conv_chans_list": [32, 32, 64, 64, 128, 128],
        "gst_conv_kernel_size": 3,
        "gst_conv_stride": 2,
        "gst_gru_layers": 1,
        "gst_gru_units": 128,
        "init_type": "xavier_uniform",
        "init_enc_alpha": 1.0,
        "init_dec_alpha": 1.0,
        "use_masking": False,
        "use_weighted_masking": False,
    },
    # discriminator related
    discriminator_type: str = "hifigan_multi_scale_multi_period_discriminator",
    discriminator_params: Dict[str, Any] = {
        "scales": 1,
        "scale_downsample_pooling": "AvgPool1d",
        "scale_downsample_pooling_params": {
            "kernel_size": 4,
            "stride": 2,
            "padding": 2,
        },
        "scale_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [15, 41, 5, 3],
            "channels": 128,
            "max_downsample_channels": 1024,
            "max_groups": 16,
            "bias": True,
            "downsample_scales": [2, 2, 4, 4, 1],
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
        "follow_official_norm": False,
        "periods": [2, 3, 5, 7, 11],
        "period_discriminator_params": {
            "in_channels": 1,
            "out_channels": 1,
            "kernel_sizes": [5, 3],
            "channels": 32,
            "downsample_scales": [3, 3, 3, 3, 1],
            "max_downsample_channels": 1024,
            "bias": True,
            "nonlinear_activation": "LeakyReLU",
            "nonlinear_activation_params": {"negative_slope": 0.1},
            "use_weight_norm": True,
            "use_spectral_norm": False,
        },
    },
    # loss related
    generator_adv_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "loss_type": "mse",
    },
    discriminator_adv_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "loss_type": "mse",
    },
    feat_match_loss_params: Dict[str, Any] = {
        "average_by_discriminators": False,
        "average_by_layers": False,
        "include_final_outputs": True,
    },
    mel_loss_params: Dict[str, Any] = {
        "fs": 22050,
        "n_fft": 1024,
        "hop_length": 256,
        "win_length": None,
        "window": "hann",
        "n_mels": 80,
        "fmin": 0,
        "fmax": None,
        "log_base": None,
    },
    lambda_adv: float = 1.0,
    lambda_mel: float = 45.0,
    lambda_feat_match: float = 2.0,
    lambda_var: float = 1.0,
    lambda_align: float = 2.0,
    cache_generator_outputs: bool = True,
):
    """Initialize JETS module.

    The ``generator_params`` dict passed in (including the shared default
    dict) is never modified: ``idim`` / ``odim`` are injected into a private
    copy.

    Args:
        idim (int): Input vocabulary size.
        odim (int): Acoustic feature dimension. The actual output channels will
            be 1 since JETS is the end-to-end text-to-wave model but for the
            compatibility odim is used to indicate the acoustic feature dimension.
        sampling_rate (int): Sampling rate, not used for the training but it will
            be referred in saving waveform during the inference.
        generator_type (str): Generator type.
        generator_params (Dict[str, Any]): Parameter dict for generator.
        discriminator_type (str): Discriminator type.
        discriminator_params (Dict[str, Any]): Parameter dict for discriminator.
        generator_adv_loss_params (Dict[str, Any]): Parameter dict for generator
            adversarial loss.
        discriminator_adv_loss_params (Dict[str, Any]): Parameter dict for
            discriminator adversarial loss.
        feat_match_loss_params (Dict[str, Any]): Parameter dict for feat match loss.
        mel_loss_params (Dict[str, Any]): Parameter dict for mel loss.
        lambda_adv (float): Loss scaling coefficient for adversarial loss.
        lambda_mel (float): Loss scaling coefficient for mel spectrogram loss.
        lambda_feat_match (float): Loss scaling coefficient for feat match loss.
        lambda_var (float): Loss scaling coefficient for variance loss.
        lambda_align (float): Loss scaling coefficient for alignment loss.
        cache_generator_outputs (bool): Whether to cache generator outputs.

    Raises:
        KeyError: If ``generator_type`` or ``discriminator_type`` is not a
            registered type.

    """
    assert check_argument_types()
    super().__init__()

    # define modules
    generator_class = AVAILABLE_GENERATERS[generator_type]
    # NOTE: Build a new dict instead of calling ``.update()`` on the argument.
    # The default value is a single dict object shared by every call, so an
    # in-place update would leak idim / odim into later instantiations and
    # would also silently modify a dict owned by the caller.
    generator_params = dict(generator_params, idim=idim, odim=odim)
    self.generator = generator_class(
        **generator_params,
    )
    discriminator_class = AVAILABLE_DISCRIMINATORS[discriminator_type]
    self.discriminator = discriminator_class(
        **discriminator_params,
    )
    self.generator_adv_loss = GeneratorAdversarialLoss(
        **generator_adv_loss_params,
    )
    self.discriminator_adv_loss = DiscriminatorAdversarialLoss(
        **discriminator_adv_loss_params,
    )
    self.feat_match_loss = FeatureMatchLoss(
        **feat_match_loss_params,
    )
    self.mel_loss = MelSpectrogramLoss(
        **mel_loss_params,
    )
    self.var_loss = VarianceLoss()
    self.forwardsum_loss = ForwardSumLoss()

    # coefficients
    self.lambda_adv = lambda_adv
    self.lambda_mel = lambda_mel
    self.lambda_feat_match = lambda_feat_match
    self.lambda_var = lambda_var
    self.lambda_align = lambda_align

    # cache
    self.cache_generator_outputs = cache_generator_outputs
    self._cache = None

    # store sampling rate for saving wav file
    # (not used for the training)
    self.fs = sampling_rate

    # store parameters for test compatibility
    self.spks = self.generator.spks
    self.langs = self.generator.langs
    self.spk_embed_dim = self.generator.spk_embed_dim