# Multi-band MelGAN training setup: builds the generator, the PQMF module,
# the multi-scale discriminator and the loss objects from
# ``malaya_speech.config.mb_melgan_config``.
# NOTE(review): `tf` and the top-level `malaya_speech` name are used below but
# are not imported in the visible source -- presumably bound earlier; confirm.
from malaya_speech.train.model import melgan, mb_melgan, stft
from malaya_speech.train.loss import calculate_2d_loss, calculate_3d_loss

mb_melgan_config = malaya_speech.config.mb_melgan_config
_generator_params = mb_melgan_config['melgan_generator_params']

# Generator: a MelGAN generator driven by the multi-band generator config.
_generator_config = mb_melgan.GeneratorConfig(**_generator_params)
generator = melgan.Generator(_generator_config, name='mb_melgan-generator')

# PQMF receives its own config instance built from the same generator
# parameters (a separate object, not the one handed to the generator).
_pqmf_config = mb_melgan.GeneratorConfig(**_generator_params)
pqmf = mb_melgan.PQMF(_pqmf_config, dtype=tf.float32, name='pqmf')

# Discriminator: MelGAN multi-scale discriminator with the multi-band config.
_discriminator_config = mb_melgan.DiscriminatorConfig(
    **mb_melgan_config['melgan_discriminator_params']
)
discriminator = melgan.MultiScaleDiscriminator(
    _discriminator_config,
    name='melgan-discriminator',
)

# Loss objects. Creation order is kept identical to the original in case the
# constructors register anything order-dependent.
mels_loss = melgan.loss.TFMelSpectrogram()
mse_loss = tf.keras.losses.MeanSquaredError()
mae_loss = tf.keras.losses.MeanAbsoluteError()

_sub_band_stft_params = mb_melgan_config['subband_stft_loss_params']
sub_band_stft_loss = stft.loss.MultiResolutionSTFT(**_sub_band_stft_params)

_full_band_stft_params = mb_melgan_config['stft_loss_params']
full_band_stft_loss = stft.loss.MultiResolutionSTFT(**_full_band_stft_params)
# HiFi-GAN training setup: builds the generator, a combined
# multi-period + multi-scale discriminator and the loss objects from
# ``malaya_speech.config.hifigan_config``.
# NOTE(review): `hifigan`, `melgan`, `stft` and `tf` are used below but are not
# imported in the visible source -- presumably bound earlier; confirm.
import malaya_speech.config
from malaya_speech.train.loss import calculate_2d_loss, calculate_3d_loss

hifigan_config = malaya_speech.config.hifigan_config

# Generator.
_generator_config = hifigan.GeneratorConfig(
    **hifigan_config['hifigan_generator_params']
)
generator = hifigan.Generator(_generator_config, name='hifigan_generator')

# Multi-period discriminator (HiFi-GAN).
_multiperiod_config = hifigan.DiscriminatorConfig(
    **hifigan_config['hifigan_discriminator_params']
)
multiperiod_discriminator = hifigan.MultiPeriodDiscriminator(
    _multiperiod_config,
    name='hifigan_multiperiod_discriminator',
)

# Multi-scale discriminator (MelGAN).
# NOTE(review): `name` is passed to DiscriminatorConfig here, not to
# MultiScaleDiscriminator. Kept exactly as in the original because moving it
# could change variable names and break existing checkpoints -- confirm intent.
_multiscale_config = melgan.DiscriminatorConfig(
    **hifigan_config['melgan_discriminator_params'],
    name='melgan_multiscale_discriminator',
)
multiscale_discriminator = melgan.MultiScaleDiscriminator(_multiscale_config)

# Wrap both discriminators into the single discriminator used for training.
discriminator = hifigan.Discriminator(
    multiperiod_discriminator,
    multiscale_discriminator,
)

# Loss objects. Creation order is kept identical to the original.
_stft_loss_params = hifigan_config['stft_loss_params']
stft_loss = stft.loss.MultiResolutionSTFT(**_stft_loss_params)
mels_loss = melgan.loss.TFMelSpectrogram()
mse_loss = tf.keras.losses.MeanSquaredError()
mae_loss = tf.keras.losses.MeanAbsoluteError()


# NOTE(review): the visible source ends partway through this function; the
# statements below are reproduced verbatim and the rest is not shown here.
def compute_per_example_generator_losses(features):
    y_hat = generator(features['mel'], training=True)
    audios = features['audio']