Пример #1
0
from malaya_speech.train.model import melgan, mb_melgan, stft
from malaya_speech.train.loss import calculate_2d_loss, calculate_3d_loss

# Multi-band MelGAN training setup: generator, PQMF, discriminator and losses.
# NOTE(review): `malaya_speech` and `tf` are not imported in the visible part
# of this snippet -- confirm they are brought into scope elsewhere.
mb_melgan_config = malaya_speech.config.mb_melgan_config

# The generator and the PQMF object each receive their own GeneratorConfig
# instance, both built from the same 'melgan_generator_params' entry.
generator_config = mb_melgan.GeneratorConfig(
    **mb_melgan_config['melgan_generator_params']
)
generator = melgan.Generator(generator_config, name='mb_melgan-generator')

pqmf_config = mb_melgan.GeneratorConfig(
    **mb_melgan_config['melgan_generator_params']
)
pqmf = mb_melgan.PQMF(pqmf_config, dtype=tf.float32, name='pqmf')

# Multi-scale discriminator configured from 'melgan_discriminator_params'.
discriminator_config = mb_melgan.DiscriminatorConfig(
    **mb_melgan_config['melgan_discriminator_params']
)
discriminator = melgan.MultiScaleDiscriminator(
    discriminator_config,
    name='melgan-discriminator',
)

# Loss objects, all created with their default arguments.
mels_loss = melgan.loss.TFMelSpectrogram()
mse_loss = tf.keras.losses.MeanSquaredError()
mae_loss = tf.keras.losses.MeanAbsoluteError()

# Two multi-resolution STFT losses with separate parameter sets:
# one taken from 'subband_stft_loss_params', one from 'stft_loss_params'.
sub_band_stft_loss = stft.loss.MultiResolutionSTFT(
    **mb_melgan_config['subband_stft_loss_params']
)
full_band_stft_loss = stft.loss.MultiResolutionSTFT(
    **mb_melgan_config['stft_loss_params']
)

Пример #2
0
import malaya_speech.config
from malaya_speech.train.loss import calculate_2d_loss, calculate_3d_loss

# HiFi-GAN training setup: generator, combined discriminator and losses.
# NOTE(review): `hifigan`, `melgan`, `stft` and `tf` are not imported in the
# visible part of this snippet -- confirm they are brought into scope elsewhere.
hifigan_config = malaya_speech.config.hifigan_config

generator_config = hifigan.GeneratorConfig(
    **hifigan_config['hifigan_generator_params']
)
generator = hifigan.Generator(generator_config, name='hifigan_generator')

period_discriminator_config = hifigan.DiscriminatorConfig(
    **hifigan_config['hifigan_discriminator_params']
)
multiperiod_discriminator = hifigan.MultiPeriodDiscriminator(
    period_discriminator_config,
    name='hifigan_multiperiod_discriminator',
)

# NOTE(review): `name` is passed to melgan.DiscriminatorConfig here, not to
# melgan.MultiScaleDiscriminator, whereas the other models in this snippet
# receive `name` directly. Kept exactly as written because moving it could
# change the model's name -- confirm whether this placement is intended.
scale_discriminator_config = melgan.DiscriminatorConfig(
    **hifigan_config['melgan_discriminator_params'],
    name='melgan_multiscale_discriminator',
)
multiscale_discriminator = melgan.MultiScaleDiscriminator(
    scale_discriminator_config
)

# Wrap the multi-period and multi-scale discriminators in one object.
discriminator = hifigan.Discriminator(
    multiperiod_discriminator,
    multiscale_discriminator,
)

# Loss objects; only the STFT loss takes parameters from the config.
stft_loss = stft.loss.MultiResolutionSTFT(
    **hifigan_config['stft_loss_params']
)
mels_loss = melgan.loss.TFMelSpectrogram()
mse_loss = tf.keras.losses.MeanSquaredError()
mae_loss = tf.keras.losses.MeanAbsoluteError()


def compute_per_example_generator_losses(features):
    y_hat = generator(features['mel'], training=True)
    audios = features['audio']