def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    """Check MelGAN generator and multi-scale discriminator are trainable.

    Runs one full GAN training step in each direction (generator update with
    adversarial + STFT aux + feature-matching losses, then discriminator
    update with real/fake MSE losses) and asserts nothing by value — the test
    passes if both backward/step calls complete.

    Args:
        dict_g: Overrides for the generator argument factory.
        dict_d: Overrides for the discriminator argument factory.
        dict_loss: Overrides for the multi-resolution STFT loss factory.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    # conditioning features: one frame per upsampling_factor output samples
    c = torch.randn(
        batch_size,
        args_g["in_channels"],
        batch_length // np.prod(args_g["upsample_scales"]),
    )
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # check generator trainable
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    y, y_hat = y.squeeze(1), y_hat.squeeze(1)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    # adversarial loss: last element of each scale's outputs is the logits
    adv_loss = 0.0
    for i in range(len(p_hat)):
        adv_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_ones(p_hat[i][-1].size()))
    adv_loss /= (i + 1)
    with torch.no_grad():
        p = model_d(y.unsqueeze(1))
    # feature-matching loss over intermediate feature maps (all but logits)
    fm_loss = 0.0
    for i in range(len(p_hat)):
        for j in range(len(p_hat[i]) - 1):
            fm_loss += F.l1_loss(p_hat[i][j], p[i][j].detach())
    # BUG FIX: average over (i + 1) * (j + 1) terms.  The original divided by
    # (i + 1) * j, which under-counts the inner loop by one (j is the LAST
    # index, not the count) and divides by zero when a scale yields exactly
    # two feature maps.
    fm_loss /= (i + 1) * (j + 1)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # check discriminator trainable
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y)
    p_hat = model_d(y_hat)
    real_loss = 0.0
    fake_loss = 0.0
    for i in range(len(p)):
        real_loss += F.mse_loss(
            p[i][-1], p[i][-1].new_ones(p[i][-1].size()))
        fake_loss += F.mse_loss(
            p_hat[i][-1], p_hat[i][-1].new_zeros(p_hat[i][-1].size()))
    real_loss /= (i + 1)
    fake_loss /= (i + 1)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_causal_melgan(dict_g):
    """Verify causal MelGAN output depends only on past conditioning frames.

    Two inputs sharing the same first half of frames must produce identical
    outputs over the corresponding first half of samples.
    """
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    upsampling_factor = np.prod(args_g["upsample_scales"])
    n_frames = batch_length // upsampling_factor
    c = torch.randn(batch_size, args_g["in_channels"], n_frames)
    model_g = MelGANGenerator(**args_g)

    # copy of c whose second half of frames is re-randomized
    half = c.size(-1) // 2
    c_altered = c.clone()
    c_altered[..., half:] = torch.randn(c[..., half:].shape)

    # sanity check: the two conditioning inputs must actually differ
    try:
        np.testing.assert_array_equal(c.numpy(), c_altered.numpy())
    except AssertionError:
        pass
    else:
        raise AssertionError("Must be different.")

    # causality: outputs over the unchanged first half must match exactly
    y = model_g(c)
    y_altered = model_g(c_altered)
    assert y.size(2) == c.size(2) * upsampling_factor
    np.testing.assert_array_equal(
        y[..., :half * upsampling_factor].detach().cpu().numpy(),
        y_altered[..., :half * upsampling_factor].detach().cpu().numpy(),
    )
def test_parallel_wavegan_compatibility():
    """Check this MelGANGenerator matches parallel_wavegan's implementation.

    Loads parallel_wavegan's weights into the local model and asserts both
    produce bit-identical inference output for the same input.
    """
    from parallel_wavegan.models import MelGANGenerator as PWGMelGANGenerator

    model_pwg = PWGMelGANGenerator(**make_melgan_generator_args())
    model_espnet2 = MelGANGenerator(**make_melgan_generator_args())
    # share weights so any output difference comes from the forward code
    model_espnet2.load_state_dict(model_pwg.state_dict())
    model_pwg.eval()
    model_espnet2.eval()

    with torch.no_grad():
        mel = torch.randn(5, 80)
        out_pwg = model_pwg.inference(mel)
        out_espnet2 = model_espnet2.inference(mel)
    np.testing.assert_array_equal(
        out_pwg.cpu().numpy(),
        out_espnet2.cpu().numpy(),
    )
def test_melgan_trainable_with_melgan_discriminator(dict_g, dict_d, dict_loss):
    """Check MelGAN G/D are trainable using the dedicated loss modules.

    Same scenario as the hand-rolled-loss variant, but adversarial and
    feature-matching losses come from GeneratorAdversarialLoss,
    DiscriminatorAdversarialLoss, and FeatureMatchLoss.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_melgan_generator_args(**dict_g)
    args_d = make_melgan_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    upsampling_factor = np.prod(args_g["upsample_scales"])
    y = torch.randn(batch_size, 1, batch_length)
    c = torch.randn(
        batch_size, args_g["in_channels"], batch_length // upsampling_factor)
    model_g = MelGANGenerator(**args_g)
    model_d = MelGANMultiScaleDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    feat_match_criterion = FeatureMatchLoss()
    gen_adv_criterion = GeneratorAdversarialLoss()
    dis_adv_criterion = DiscriminatorAdversarialLoss()
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # generator step
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    aux_loss = sc_loss + mag_loss
    adv_loss = gen_adv_criterion(p_hat)
    with torch.no_grad():
        p = model_d(y)
    fm_loss = feat_match_criterion(p_hat, p)
    loss_g = adv_loss + aux_loss + fm_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # discriminator step (generator output detached)
    p = model_d(y)
    p_hat = model_d(y_hat.detach())
    real_loss, fake_loss = dis_adv_criterion(p_hat, p)
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
def test_melgan_trainable_with_residual_discriminator(dict_g, dict_d, dict_loss):
    """Check MelGAN generator trains against a residual PWG discriminator.

    One generator update (LSGAN adversarial + multi-resolution STFT aux
    losses) followed by one discriminator update on real vs detached fake.
    """
    # setup
    batch_size = 4
    batch_length = 4096
    args_g = make_generator_args(**dict_g)
    args_d = make_residual_discriminator_args(**dict_d)
    args_loss = make_mutli_reso_stft_loss_args(**dict_loss)
    y = torch.randn(batch_size, 1, batch_length)
    n_frames = batch_length // np.prod(args_g["upsample_scales"])
    c = torch.randn(batch_size, args_g["in_channels"], n_frames)
    model_g = MelGANGenerator(**args_g)
    model_d = ResidualParallelWaveGANDiscriminator(**args_d)
    aux_criterion = MultiResolutionSTFTLoss(**args_loss)
    optimizer_g = RAdam(model_g.parameters())
    optimizer_d = RAdam(model_d.parameters())

    # generator step
    y_hat = model_g(c)
    p_hat = model_d(y_hat)
    y, y_hat, p_hat = y.squeeze(1), y_hat.squeeze(1), p_hat.squeeze(1)
    adv_loss = F.mse_loss(p_hat, p_hat.new_ones(p_hat.size()))
    sc_loss, mag_loss = aux_criterion(y_hat, y)
    loss_g = adv_loss + sc_loss + mag_loss
    optimizer_g.zero_grad()
    loss_g.backward()
    optimizer_g.step()

    # discriminator step (fake branch detached from the generator graph)
    y, y_hat = y.unsqueeze(1), y_hat.unsqueeze(1).detach()
    p = model_d(y).squeeze(1)
    p_hat = model_d(y_hat).squeeze(1)
    real_loss = F.mse_loss(p, p.new_ones(p.size()))
    fake_loss = F.mse_loss(p_hat, p_hat.new_zeros(p_hat.size()))
    loss_d = real_loss + fake_loss
    optimizer_d.zero_grad()
    loss_d.backward()
    optimizer_d.step()
# Following instructions from notebook
# pip install -qq .
# pip install -qq tensorflow-gpu==2.1
import os

import numpy as np
import tensorflow as tf
import torch
import yaml
from tensorflow.python.framework import convert_to_constants
from tensorflow.python.saved_model import signature_constants, tag_constants

from parallel_wavegan.models import MelGANGenerator

# setup pytorch model
vocoder_conf = './egs/ljspeech/voc1/conf/melgan.v1.long.yaml'
with open(vocoder_conf) as f:
    # NOTE(review): yaml.Loader can construct arbitrary python objects;
    # acceptable for this local config, but use yaml.SafeLoader if the file
    # is ever not fully trusted.
    config = yaml.load(f, Loader=yaml.Loader)
pytorch_melgan = MelGANGenerator(**config["generator_params"])
pytorch_melgan.remove_weight_norm()

# TODO: Train MelGAN (Save state_dict to checkpoint and time to train)
pytorch_melgan = pytorch_melgan.to("cuda").eval()

# checks inference speed
fake_mels = np.random.sample((4, 1500, 80)).astype(np.float32)
# BUG FIX: the model is a torch module on CUDA, so the input must be a torch
# tensor on the same device, and the generator consumes channels-first
# (batch, channels, time); the original passed the raw (B, T, C) numpy array
# directly, which cannot be consumed by torch layers.
fake_mels_tensor = torch.from_numpy(fake_mels).transpose(2, 1).to("cuda")
with torch.no_grad():
    y = pytorch_melgan(fake_mels_tensor)
# TODO: check LJSpeech inference speed and MCD metric